# Owner(s): ["module: nn"]
import itertools
import math
import os
import unittest
import warnings
from itertools import product

import torch
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import make_tensor


def _get_cudnn_version():
    """Safely get cuDNN version, returning None if unavailable."""
    try:
        return torch.backends.cudnn.version()
    except RuntimeError:
        return None


from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off
from torch.testing._internal.common_device_type import (
    disablecuDNN,
    disableMkldnn,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    expectedFailureMPS,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    precisionOverride,
    skipCPUIfNoMkldnn,
    skipCUDAIfMiopen,
    skipCUDAIfNoCudnn,
    skipCUDAIfNoMiopen,
    skipCUDAIfRocm,
    skipMeta,
    skipMPS,
)
from torch.testing._internal.common_dtype import (
    floating_and_complex_types_and,
    floating_types_and,
)
from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase
from torch.testing._internal.common_utils import (
    download_file,
    dtype2prec_DONTUSE,
    gradcheck,
    GRADCHECK_NONDET_TOL,
    gradgradcheck,
    instantiate_parametrized_tests,
    MACOS_VERSION,
    MI300_ARCH,
    parametrize as parametrize_test,
    run_tests,
    serialTest,
    set_default_dtype,
    skipIfRocmArch,
    subtest,
    TEST_SCIPY,
    TEST_WITH_ROCM,
    xfailIf,
)


# True when the suite runs with ROCm testing enabled, or when
# torch.cuda.is_tf32_supported() reports TF32 support.
AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()


# For ROCm runs, set the MIOpen NHWC environment switches at import time.
# NOTE(review): presumably these make MIOpen prefer channels-last kernels for
# conv and batchnorm — confirm against the MIOpen integration docs.
if TEST_WITH_ROCM:
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"


# SciPy is optional: import its submodules only when TEST_SCIPY is set.
if TEST_SCIPY:
    import scipy.ndimage
    import scipy.signal


class TestConvolutionNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_conv_backcompat(self):
        # Load a Conv2d that was pickled by an old PyTorch build and make sure
        # it can still run a forward pass.
        #
        # The checkpoint was generated on PyTorch 1.0.1 under Python 2 with:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This pickle also contains some Unicode data!
        from torch.serialization import SourceChangeWarning

        checkpoint = download_file(
            "https://download.pytorch.org/test_data/legacy_conv2d.pt"
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", SourceChangeWarning)
            # The legacy file stores the whole module object, so it has to be
            # loaded with weights_only=False.
            legacy_conv = torch.load(checkpoint, encoding="utf-8", weights_only=False)
        sample = torch.randn((1, 1, 1, 1), dtype=torch.float)
        result = legacy_conv(sample)
        self.assertEqual(result.size(), (1, 1, 1, 1))

    def test_huge_padding(self):
        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv1d(
                    in_channels=16,
                    out_channels=32,
                    kernel_size=3,
                    stride=1,
                    padding=9223372036854775803,
                )
                self.add_module(name="conv1", module=self.conv1)

        input_data = torch.randn(1, 16, 100)
        model = Conv1dModule()
        with self.assertRaisesRegex(
            RuntimeError,
            r"Given padding=9223372036854775803 at dimension 0 , expected padding to be at most",
        ):
            model.conv1(input_data)

        class ConvTransposed1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_transposed1d = nn.ConvTranspose1d(
                    in_channels=16,
                    out_channels=32,
                    kernel_size=3,
                    stride=2,
                    padding=9223372036854775803,
                )
                self.add_module(name="conv_transposed1d", module=self.conv_transposed1d)

        input_data = torch.randn(1, 16, 100)
        model = ConvTransposed1dModule()
        with self.assertRaisesRegex(
            RuntimeError,
            r"Given padding=9223372036854775803 at dimension 0 , expected padding to be at most",
        ):
            model.conv_transposed1d(input_data)

    def test_invalid_conv1d(self):
        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            module = nn.Conv1d(
                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError,
                r"Calculated padded input size per channel: \(4\). "
                + r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size",
            ):
                module(input)

            # Negative stride check
            module = nn.Conv1d(
                in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)

    def test_mismatch_shape_conv2d(self):
        for dtype in (torch.float, torch.cfloat):
            x = torch.randn(1, 10, 1, 28, 28, dtype=dtype)
            w = torch.randn(6, 1, 5, 5, dtype=dtype)

            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got "
                + r"input of size: \[1, 10, 1, 28, 28\]",
            ):
                F.conv2d(x, w)

    def test_conv2d_discontiguous_weight(self):
        for dtype in (torch.float, torch.cfloat):
            # Test for https://github.com/pytorch/pytorch/issues/55781
            x = torch.ones(64, 16, 16, 16, dtype=dtype)
            weight = (
                torch.arange(0, 1.0, 1 / 2.0**10)
                .reshape(32, 16, 1, 2)
                .to(dtype)[:, :, :, ::2]
            )
            self.assertFalse(weight.is_contiguous())
            y = torch.nn.functional.conv2d(x, weight, None)
            if torch.backends.mkldnn.is_available():
                # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
                with torch.backends.mkldnn.flags(enabled=False):
                    y_ = torch.nn.functional.conv2d(x, weight, None)
                    self.assertEqual(y, y_)
            self.assertEqual(y.sum(), 4186112.0)

    def test_invalid_conv2d(self):
        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(
                dtype
            )
            input = torch.empty(1, 1, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            module = nn.Conv2d(
                in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True
            )
            input = torch.randn(1, 3, 1, 1)
            with self.assertRaisesRegex(
                RuntimeError,
                r"Calculated padded input size per channel: \(1 x 1\). "
                + r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size",
            ):
                module(input)

            # Negative stride check
            module = nn.Conv2d(
                in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)

            # Zero stride check
            module = nn.Conv2d(
                in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True
            ).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)

    def test_invalid_conv3d(self):
        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(
                dtype
            )
            input = torch.empty(1, 1, 4, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            # Negative stride check
            module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2)
            input = torch.empty(1, 1, 4, 4, 4)
            with self.assertRaisesRegex(
                RuntimeError, "non-positive stride is not supported"
            ):
                module(input)

    def test_conv_invalid_groups(self):
        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
            torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0)
        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
            torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1)
        with self.assertRaisesRegex(ValueError, "groups must be a positive integer"):
            torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2)

    def test_conv_aten_invalid_groups(self):
        # test low-level aten ops with invalid groups parameter
        grad_output = torch.randn(2, 4, 8, dtype=torch.double)
        input = torch.randn(2, 5, 8, dtype=torch.double)
        weight = torch.randn(5, 4, 3, dtype=torch.double)
        bias_sizes = [4]
        stride = [1]
        padding = [1]
        dilation = [1]
        transposed = True
        output_padding = [0]
        output_mask = [True, True, True]

        # test groups=0
        with self.assertRaisesRegex(
            RuntimeError, "expected groups to be greater than 0, but got groups=0"
        ):
            torch.ops.aten.convolution_backward(
                grad_output,
                input,
                weight,
                bias_sizes,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                0,
                output_mask,
            )

        # test groups=-1
        with self.assertRaisesRegex(
            RuntimeError, "expected groups to be greater than 0, but got groups=-1"
        ):
            torch.ops.aten.convolution_backward(
                grad_output,
                input,
                weight,
                bias_sizes,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                -1,
                output_mask,
            )

    def test_conv3d_overflow_values(self):
        input = torch.full(
            (
                0,
                7,
                9,
                1,
                5,
            ),
            0,
            dtype=torch.float32,
            requires_grad=False,
        )
        weight = torch.full(
            (
                9,
                1,
            ),
            4.14214e16,
            dtype=torch.float32,
            requires_grad=False,
        )
        stride = [5, 5, 5]

        with self.assertRaisesRegex(ValueError, "Padding height too large"):
            torch.ops.aten.slow_conv3d(
                input,
                weight,
                kernel_size=[5, 5, 5],
                bias=None,
                stride=stride,
                padding=[2**62, 2**62, 2**62],
            )

        with self.assertRaisesRegex(
            RuntimeError, "Kernel height x width product is too large:"
        ):
            torch.ops.aten.slow_conv3d(
                input,
                weight,
                kernel_size=[2**32, 2**32, 2**32],
                bias=None,
                stride=stride,
                padding=[2**31, 2**31, 2**31],
            )

    def test_Conv1d_module_same_padding(self):
        # Compare module against functional: without strides/dilation, asymmetric padding
        x = torch.rand(1, 1, 20)
        module = nn.Conv1d(
            in_channels=1, out_channels=1, kernel_size=10, padding="same"
        )
        expect = F.conv1d(x, module.weight, module.bias, padding="same")
        self.assertEqual(expect, module(x))

        # Test dilation, symmetric padding
        module = nn.Conv1d(
            in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2
        )
        expect = F.conv1d(x, module.weight, module.bias, padding="same", dilation=2)
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=10,
            padding="same",
            padding_mode="replicate",
        )
        x_padded = F.pad(x, [4, 5], mode="replicate")
        expect = F.conv1d(x_padded, module.weight, module.bias, padding="valid")
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            module = nn.Conv1d(
                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
            )

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv1d(
                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
            )

    def test_Conv2d_module_same_padding(self):
        # Compare module against functional:
        # without strides/dilation, both symmetric and asymmetric padding
        x = torch.rand(1, 1, 9, 20)
        module = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same"
        )
        expect = F.conv2d(x, module.weight, module.bias, padding="same")
        self.assertEqual(expect, module(x))

        # with dilation, symmetric padding
        module = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(3, 4),
            padding="same",
            dilation=(1, 2),
        )
        expect = F.conv2d(
            x, module.weight, module.bias, padding="same", dilation=(1, 2)
        )
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(3, 4),
            padding="same",
            padding_mode="reflect",
        )
        x_padded = F.pad(x, [1, 2, 1, 1], mode="reflect")
        expect = F.conv2d(x_padded, module.weight, module.bias, padding="valid")
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            module = nn.Conv2d(
                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
            )

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(
                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
            )
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(
                in_channels=3,
                out_channels=33,
                kernel_size=10,
                padding="same",
                stride=(1, 3),
            )
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv2d(
                in_channels=3,
                out_channels=33,
                kernel_size=10,
                padding="same",
                stride=(4, 1),
            )

    def test_Conv3d_module_same_padding(self):
        # Compare module against functional:
        x = torch.rand(1, 1, 4, 4, 4)
        # without dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(
            in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same"
        )
        expect = F.conv3d(x, module.weight, module.bias, padding="same")
        self.assertEqual(expect, module(x))

        # with dilation, both symmetric and asymmetric padding
        module = nn.Conv3d(
            in_channels=1,
            out_channels=1,
            kernel_size=(2, 3, 4),
            padding="same",
            dilation=(3, 2, 1),
        )
        expect = F.conv3d(
            x, module.weight, module.bias, padding="same", dilation=(3, 2, 1)
        )
        self.assertEqual(expect, module(x))

        # Test non-zero padding_mode, requiring explicit padding
        module = nn.Conv3d(
            in_channels=1,
            out_channels=1,
            kernel_size=(2, 3, 4),
            padding="same",
            padding_mode="circular",
        )
        x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular")
        expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid")
        self.assertEqual(expect, module(x))
        self.assertEqual(x.size(), expect.size())

        # Test connstruction with invalid padding string raises
        with self.assertRaisesRegex(ValueError, "Invalid padding string"):
            module = nn.Conv3d(
                in_channels=3, out_channels=33, kernel_size=10, padding="foo"
            )

        # Test connstruction with same padding and strides raises
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(
                in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2
            )
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(
                in_channels=3,
                out_channels=33,
                kernel_size=10,
                padding="same",
                stride=(1, 1, 3),
            )
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(
                in_channels=3,
                out_channels=33,
                kernel_size=10,
                padding="same",
                stride=(1, 4, 1),
            )
        with self.assertRaisesRegex(ValueError, "padding='same'"):
            module = nn.Conv3d(
                in_channels=3,
                out_channels=33,
                kernel_size=10,
                padding="same",
                stride=(5, 1, 1),
            )

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_thnn_conv_strided_padded_dilated(self):
        # With cuDNN disabled, CUDA (transposed) convolutions must match the
        # CPU result and pass gradcheck for several stride / padding /
        # dilation combinations.
        conv_variants = (
            (torch.nn.functional.conv2d, 2, False),
            (torch.nn.functional.conv_transpose2d, 2, True),
            (torch.nn.functional.conv3d, 3, False),
            (torch.nn.functional.conv_transpose3d, 3, True),
        )
        # (stride, padding, dilation)
        geometries = ((2, 0, 1), (1, 1, 1), (2, 1, 1), (1, 0, 2))

        for (convfn, dims, _transposed), (stride, padding, dilation) in product(
            conv_variants, geometries
        ):
            kwargs = {"stride": stride, "padding": padding, "dilation": dilation}

            def make_param(shape):
                return torch.randn(
                    shape, dtype=torch.double, device="cuda", requires_grad=True
                )

            inputs = make_param((1, 2) + dims * (4,))
            weight = make_param((2, 2) + dims * (1,))
            bias = make_param(2)

            def run_conv(x, w, b):
                return convfn(x, w, b, **kwargs)

            with torch.backends.cudnn.flags(enabled=False):
                res = convfn(inputs, weight, bias, **kwargs)
            res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs)
            self.assertEqual(res, res_cpu)

            with torch.backends.cudnn.flags(enabled=False):
                torch.autograd.gradcheck(run_conv, (inputs, weight, bias))
                torch.autograd.gradcheck(
                    run_conv, (inputs.cpu(), weight.cpu(), bias.cpu())
                )

    def test_Conv2d_inconsistent_types(self):
        inputs = torch.randn(4, 1, 7, 7, dtype=torch.float)
        weights = torch.randn(1, 1, 3, 3, dtype=torch.double)
        # inconsistent types should raise an exception
        self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights))
        # but it should work with the same type
        nn.functional.conv2d(inputs.float(), weights.float())

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self):
        # Mixed float/double operands must be rejected on CUDA when cuDNN is
        # disabled.
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=False):
            # Mismatched weight dtype
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight)
            # Matching weight but mismatched bias dtype
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight.float(), double_bias)

            # Everything in float must work
            nn.functional.conv2d(
                float_input.float(), double_weight.float(), double_bias.float()
            )

    def test_Conv2d_1x1(self):
        in_channels = 2
        mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double)
        input = torch.randn(
            1, in_channels, 5, 5, requires_grad=True, dtype=torch.double
        )
        for enabled in (False, True):
            with torch.backends.mkldnn.flags(enabled=enabled):
                gradcheck(F.conv2d, (input, mod.weight))

    def test_Conv2d_OneDNN(self):
        def run_once(group_val=24, dilation=1):
            ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32)
            weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32)
            op = torch.nn.Conv2d(
                in_channels=group_val,
                out_channels=group_val,
                kernel_size=[3, 3],
                stride=[2, 2],
                padding=[1, 1],
                dilation=[dilation, dilation],
                groups=group_val,
                bias=False,
                padding_mode="zeros",
            )

            op.weight.data = weights
            res = op(ifm)
            grad_in = torch.ones(res.shape, dtype=torch.float32)
            res.backward(grad_in)
            return op.weight.grad

        for gorup_val in (24, 48, 23, 25):
            for dilation in (1, 2):
                with torch.backends.mkldnn.flags(enabled=False):
                    without_onednn = run_once(gorup_val, dilation)

                with torch.backends.mkldnn.flags(enabled=True):
                    with_onednn = run_once(gorup_val, dilation)

                self.assertEqual(without_onednn, with_onednn)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_cudnn_non_contiguous(self):
        # Smoke test: a CUDA Conv1d must accept a non-contiguous input.
        base = torch.randn(192, 16, 50).cuda()
        # permute -> contiguous -> permute restores the logical shape but
        # leaves the tensor with permuted strides.
        x = base.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        conv = torch.nn.Conv1d(
            in_channels=16, out_channels=32, kernel_size=2, bias=True
        ).cuda()
        conv(x)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_cudnn_not_mutate_stride(self):
        # Running a convolution on channels_last and contiguous inputs must
        # leave the weight's strides untouched and give the same result.
        # NOTE(review): the tensors below are created on the default device
        # and never moved to CUDA — confirm this is intended for a cuDNN test.
        weight = torch.randn(64, 64, 1, 1)
        x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last)
        original_stride = weight.stride()

        def run_conv(inp, kernel):
            return torch.convolution(
                inp,
                kernel,
                stride=(1, 1),
                padding=(0, 0),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
                bias=None,
            )

        # should have run in nhwc without mutating input strides
        out_nhwc = run_conv(x, weight)
        self.assertEqual(weight.stride(), original_stride)
        self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last))

        # Same convolution on a plain contiguous copy of the input
        x = x.contiguous(memory_format=torch.contiguous_format)
        out_c = run_conv(x, weight)
        self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format))
        self.assertEqual(out_c, out_nhwc)
        self.assertEqual(weight.stride(), original_stride)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    @unittest.skipIf(not TEST_CUDNN, "CUDNN not available")
    def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self):
        # Mixed float/double operands must be rejected on CUDA when cuDNN is
        # enabled.
        float_input = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda")
        double_weight = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda")
        double_bias = torch.randn(1, dtype=torch.double, device="cuda")

        with torch.backends.cudnn.flags(enabled=True):
            # Mismatched weight dtype
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight)
            # Matching weight but mismatched bias dtype
            with self.assertRaises(RuntimeError):
                nn.functional.conv2d(float_input, double_weight.float(), double_bias)

            # Everything in float must work
            nn.functional.conv2d(
                float_input.float(), double_weight.float(), double_bias.float()
            )

    def test_Conv2d_missing_argument(self):
        c = nn.Conv2d(3, 3, 3)
        self.assertRaises(TypeError, lambda: c(None))

    def test_Conv2d_backward_twice(self):
        input = torch.randn(2, 3, 5, 5)
        c = nn.Conv2d(3, 3, 3)
        o1 = c(input)
        o1.sum().backward()
        self.assertRaisesRegex(
            RuntimeError, "Specify retain_graph=True", lambda: o1.sum().backward()
        )

    def test_conv_modules_raise_error_on_incorrect_input_size(self):
        for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]:
            modules = [
                nn.Conv1d(3, 8, 3).to(dtype),
                nn.ConvTranspose1d(3, 8, 3).to(dtype),
                nn.Conv2d(3, 8, 3).to(dtype),
                nn.ConvTranspose2d(3, 8, 3).to(dtype),
                nn.Conv3d(3, 8, 3).to(dtype),
                nn.ConvTranspose3d(3, 8, 3).to(dtype),
            ]

            invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)]

            for invalid_dims, module in zip(invalid_input_dims, modules):
                for dims in invalid_dims:
                    input = torch.empty(torch.Size((3,) * dims))
                    self.assertRaises(RuntimeError, lambda: module(input))

    def test_conv_shapecheck(self):
        def test(should_raise, module, input_size, dtype):
            input = torch.empty(3, *input_size).to(dtype)
            if should_raise:
                self.assertRaises(RuntimeError, lambda: module(input))
            else:
                # just run it to ensure no exception raised.
                module(input)

        for dtype in [
            torch.half,
            torch.bfloat16,
            torch.float,
            torch.double,
            torch.cfloat,
            torch.cdouble,
        ]:
            # Conv1d
            test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype)
            test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype)
            test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype)
            test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype)
            test(
                False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype
            )

            # Conv2d
            test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype)
            test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype)
            test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype)

            # Conv3D
            test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype)
            test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype)
            test(
                False,
                nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype),
                (1, 2, 2, 2),
                dtype,
            )

    def test_ConvTranspose2d_output_size(self):
        m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2)
        i = torch.randn(2, 3, 6, 6)
        for h in range(15, 22):
            for w in range(15, 22):
                if 18 <= h <= 20 and 18 <= w <= 20:
                    output = m(i, output_size=(h, w))
                    self.assertEqual(output.size()[2:], (h, w))
                else:
                    self.assertRaises(ValueError, lambda: m(i, (h, w)))

    def test_ConvTranspose2d_output_size_downsample_upsample(self):
        b, c, hid_c = 2, 3, 2
        for h in range(13, 24):
            for w in range(13, 17):
                for k in range(2, 5):
                    for d in range(1, 5):
                        for s in range(1, 4):
                            for p in range(3):
                                conv = nn.Conv2d(
                                    in_channels=c,
                                    out_channels=hid_c,
                                    kernel_size=k,
                                    stride=s,
                                    padding=p,
                                    dilation=d,
                                )

                                t_conv = nn.ConvTranspose2d(
                                    in_channels=hid_c,
                                    out_channels=c,
                                    kernel_size=k,
                                    stride=s,
                                    padding=p,
                                    dilation=d,
                                )

                                i = torch.randn(b, c, h, w)

                                out = t_conv(conv(i), output_size=i.shape)

                                self.assertEqual(out.size()[2:], i.size()[2:])

    def test_ConvTranspose3d_correct_output_size(self):
        # Check that ConvTranspose3d can take a 5d output_size.
        m = nn.ConvTranspose3d(2, 2, 2)
        i = torch.rand(1, 2, 1, 1, 1)
        m(i, output_size=(1, 2, 2, 2, 2))

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_ConvTranspose2d_half_cublas_gemm(self):
        # Smoke test: forward and backward of a half-precision CUDA
        # ConvTranspose2d must run with cuDNN disabled.
        with torch.backends.cudnn.flags(enabled=False):
            inputs = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half)
            deconv = nn.ConvTranspose2d(
                1, 1, 3, stride=2, padding=1, output_padding=1
            )
            deconv = deconv.cuda().half()
            loss = deconv(inputs).mean()
            loss.backward()

    # For https://github.com/pytorch/pytorch/pull/1273
    # Almost identical to the above `test_Conv2d_naive_groups`
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    @tf32_on_and_off(0.001)
    def test_Conv2d_groups_nobias(self):
        """Check a bias-free ``groups=2`` Conv2d against two ungrouped convolutions.

        Each group is replayed as a stand-alone Conv2d over its own channel
        slice; the concatenated outputs, input gradients and weight gradients
        must match those of the grouped module.
        """
        configs = [("cpu", torch.float)]
        if TEST_CUDA:
            configs.extend([("cuda", torch.float), ("cuda", torch.half)])
        if AMPERE_OR_ROCM:
            configs.append(("cuda", torch.bfloat16))

        # Input and output channels of a group share the same index range here.
        group_slices = (slice(None, 2), slice(2, None))

        for device, dtype in configs:
            grouped = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(
                device, dtype
            )
            x = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
            y = grouped(x)
            grad_y = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
            y.backward(grad_y)

            singles, x_parts, y_parts = [], [], []
            for chans in group_slices:
                single = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype)
                single.weight.data.copy_(grouped.weight.data[chans])
                x_part = x.data[:, chans].contiguous().requires_grad_(True)
                y_part = single(x_part)
                y_part.backward(grad_y[:, chans].contiguous())
                singles.append(single)
                x_parts.append(x_part)
                y_parts.append(y_part)

            self.assertEqual(y, torch.cat(y_parts, 1))
            self.assertEqual(
                x.grad.data,
                torch.cat([part.grad.data for part in x_parts], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            # Half-precision weight gradients get a looser absolute tolerance.
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(
                grouped.weight.grad.data,
                torch.cat([single.weight.grad.data for single in singles], 0),
                atol=weight_atol,
                rtol=0,
            )

    # Almost identical to the above `test_Conv2d_naive_groups`
    # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16
    # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686
    # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    @tf32_on_and_off(0.001)
    def test_Conv2d_groups_nobias_v2(self):
        """Check a bias-free 4->16 channel ``groups=2`` Conv2d against two 2->8 convolutions.

        The grouped module's outputs, input gradients and weight gradients must
        equal the concatenation of the per-group results.
        """
        torch.manual_seed(123)
        configs = [("cpu", torch.float)]
        if TEST_CUDA:
            configs.extend([("cuda", torch.float), ("cuda", torch.half)])
        if AMPERE_OR_ROCM:
            configs.append(("cuda", torch.bfloat16))

        # (input-channel slice, output-channel slice) for each of the two groups.
        group_slices = (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        )

        for device, dtype in configs:
            grouped = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(
                device, dtype
            )
            x = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
            y = grouped(x)
            grad_y = torch.randn(2, 16, 4, 4, device=device, dtype=dtype)
            y.backward(grad_y)

            singles, x_parts, y_parts = [], [], []
            for in_chans, out_chans in group_slices:
                single = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype)
                single.weight.data.copy_(grouped.weight.data[out_chans])
                x_part = x.data[:, in_chans].contiguous().requires_grad_(True)
                y_part = single(x_part)
                y_part.backward(grad_y[:, out_chans].contiguous())
                singles.append(single)
                x_parts.append(x_part)
                y_parts.append(y_part)

            self.assertEqual(y, torch.cat(y_parts, 1))
            self.assertEqual(
                x.grad.data,
                torch.cat([part.grad.data for part in x_parts], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            # Half-precision weight gradients get a looser absolute tolerance.
            weight_atol = 1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype]
            self.assertEqual(
                grouped.weight.grad.data,
                torch.cat([single.weight.grad.data for single in singles], 0),
                atol=weight_atol,
                rtol=0,
            )

    # CPU-only test for group conv3d fast implementation using bmm
    # See: https://github.com/pytorch/pytorch/pull/36355
    def test_Conv3d_groups_nobias(self):
        """Check a bias-free ``groups=2`` Conv3d on CPU against two ungrouped Conv3d modules."""
        torch.manual_seed(123)
        placement = ("cpu", torch.float)

        grouped = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to(*placement)
        x = torch.randn(
            2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
        )
        y = grouped(x)
        grad_y = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        y.backward(grad_y)

        singles, x_parts, y_parts = [], [], []
        # (input-channel slice, output-channel slice) for each of the two groups.
        for in_chans, out_chans in (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        ):
            single = nn.Conv3d(2, 8, kernel_size=3, bias=False).to(*placement)
            single.weight.data.copy_(grouped.weight.data[out_chans])
            x_part = x.data[:, in_chans].contiguous().requires_grad_(True)
            y_part = single(x_part)
            y_part.backward(grad_y[:, out_chans].contiguous())
            singles.append(single)
            x_parts.append(x_part)
            y_parts.append(y_part)

        tol = dtype2prec_DONTUSE[torch.float]
        self.assertEqual(y, torch.cat(y_parts, 1))
        # Input gradients are compared with an absolute tolerance only.
        self.assertEqual(
            x.grad.data,
            torch.cat([part.grad.data for part in x_parts], 1),
            atol=tol,
            rtol=0,
        )
        self.assertEqual(
            grouped.weight.grad.data,
            torch.cat([single.weight.grad.data for single in singles], 0),
            atol=tol,
            rtol=tol,
        )

    def test_Conv3d_groups_wbias(self):
        """Check a ``groups=2`` Conv3d with bias on CPU against two ungrouped Conv3d modules.

        Outputs, input gradients, weight gradients and bias gradients of the
        grouped module must match the concatenated per-group results.
        """
        torch.manual_seed(123)
        placement = ("cpu", torch.float)

        grouped = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to(*placement)
        x = torch.randn(
            2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True
        )
        y = grouped(x)
        grad_y = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float)
        y.backward(grad_y)

        singles, x_parts, y_parts = [], [], []
        # (input-channel slice, output-channel slice) for each of the two groups.
        for in_chans, out_chans in (
            (slice(None, 2), slice(None, 8)),
            (slice(2, None), slice(8, None)),
        ):
            single = nn.Conv3d(2, 8, kernel_size=3, bias=True).to(*placement)
            single.weight.data.copy_(grouped.weight.data[out_chans])
            single.bias.data.copy_(grouped.bias.data[out_chans])
            x_part = x.data[:, in_chans].contiguous().requires_grad_(True)
            y_part = single(x_part)
            y_part.backward(grad_y[:, out_chans].contiguous())
            singles.append(single)
            x_parts.append(x_part)
            y_parts.append(y_part)

        tol = dtype2prec_DONTUSE[torch.float]
        self.assertEqual(y, torch.cat(y_parts, 1))
        self.assertEqual(
            x.grad.data,
            torch.cat([part.grad.data for part in x_parts], 1),
            atol=tol,
            rtol=tol,
        )
        self.assertEqual(
            grouped.weight.grad.data,
            torch.cat([single.weight.grad.data for single in singles], 0),
            atol=tol,
            rtol=tol,
        )
        self.assertEqual(
            grouped.bias.grad.data,
            torch.cat([single.bias.grad.data for single in singles], 0),
            atol=tol,
            rtol=tol,
        )

    def test_conv_tbc(self):
        """Run ``gradcheck`` on ``F.conv_tbc`` with double-precision inputs."""
        with set_default_dtype(torch.double):
            signal = torch.randn(9, 4, 5, requires_grad=True)
            kernel = torch.randn(3, 5, 6, requires_grad=True)
            offset = torch.randn(6, requires_grad=True)
            pad = 3

            def conv_tbc(i, w, b, p):
                return F.conv_tbc(i, w, b, p)

            gradcheck(conv_tbc, (signal, kernel, offset, pad))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_grouped_conv_cudnn_nhwc_support(self):
        """Smoke-test grouped fp16 channels-last convolution, regular and transposed."""

        # Meant to catch holes in grouped-convolution NHWC support in earlier
        # cuDNN versions.
        def nhwc_half(*shape):
            dense = torch.randn(shape, dtype=torch.float16, device="cuda")
            return dense.to(memory_format=torch.channels_last)

        stride, padding, dilation = (1, 1), (1, 1), (1, 1)
        output_padding, groups = (0, 0), 4

        x = nhwc_half(16, 16, 8, 8)
        kernel = nhwc_half(8, 4, 3, 3)
        torch.convolution(
            x, kernel, None, stride, padding, dilation, False, output_padding, groups
        )

        # Same kernel, now used as a transposed convolution.
        x = nhwc_half(16, 8, 8, 8)
        torch.convolution(
            x, kernel, None, stride, padding, dilation, True, output_padding, groups
        )

    @unittest.expectedFailure
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
    def test_conv_cudnn_memory_layout_dominance(self):
        """Pin the desired rule that ``conv.weight``'s memory format decides the output's.

        Marked ``expectedFailure``: per the note below, current behavior does
        not yet follow this rule.
        """
        # desired behavior here is to have the memory_layout of conv.weight to
        # dominate the layout of output.
        # which is not the same as current behavior, we'll fix this in
        # following up PRs and remove the `expectedFailure` tag
        input = torch.randint(
            1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True
        )
        conv = nn.Conv2d(8, 4, 3).cuda().float()

        # Default-layout input with default-layout weight.
        out = conv(input)
        self.assertTrue(out.is_contiguous())

        # channels_last input with default-layout weight: output is still
        # expected to be contiguous in the default format.
        input = input.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous())

        # channels_last input with channels_last weight.
        conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last)
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

        # Default-layout input with channels_last weight: output is still
        # expected to be channels_last.
        input = input.contiguous()
        out = conv(input)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_cudnn_noncontiguous_weight(self):
        """conv1d must give the same result for an expanded weight and its contiguous copy."""
        # Noncontiguous weights must be contiguous() before being
        # passed to cuDNN
        signal = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3)

        def expanded_ones():
            return torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2)

        strided_weight = expanded_ones()
        dense_weight = expanded_ones().contiguous()

        conv_args = dict(bias=None, stride=2, dilation=2)
        self.assertEqual(
            F.conv1d(signal, strided_weight, **conv_args),
            F.conv1d(signal, dense_weight, **conv_args),
        )

    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"):
        """Compare a ``torch.nn.grad`` convolution helper against autograd.

        For a sweep of kernel sizes, input sizes, batch sizes, strides and
        paddings, with and without a bias, the gradient produced by autograd is
        compared with the one returned by ``func_backward``.

        Args:
            func_forward: functional convolution (e.g. ``F.conv1d``); called with
                ``input, weight`` and keyword ``stride``, ``padding``,
                ``dilation`` and ``bias``.
            func_backward: matching gradient helper (e.g.
                ``F.grad.conv1d_input``); called with input (or its shape),
                weight (or its shape), the output gradient, and keyword
                ``stride``, ``padding`` and ``dilation``.
            dim: number of spatial dimensions of the convolution.
            gradient: ``"input"`` to check the gradient w.r.t. the input,
                ``"weight"`` to check the gradient w.r.t. the weight.
        """
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out, dilation in product(
                [1, 2], [1, 2], [0, 1, 2], [2], [3], [1]
            ):
                for has_bias in [True, False]:
                    input_shape = [batch, chan_in] + [inp_size] * dim
                    weight_shape = [chan_out, chan_in] + [kern] * dim

                    input = torch.randn(input_shape, requires_grad=True)
                    weight = torch.randn(weight_shape, requires_grad=True)
                    # Reset the bias on every iteration: otherwise the
                    # ``has_bias=False`` pass would silently reuse the tensor
                    # created by the preceding ``has_bias=True`` pass and the
                    # bias-free path would never be exercised.
                    bias = (
                        torch.randn([chan_out], requires_grad=True)
                        if has_bias
                        else None
                    )
                    output = func_forward(
                        input,
                        weight,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        bias=bias,
                    )

                    gradient_o = torch.randn(output.shape)
                    gradient_w = torch.autograd.grad(
                        output, input if (gradient == "input") else weight, gradient_o
                    )

                    # The helper takes the *shape* of the tensor whose gradient
                    # is requested and the actual tensor for the other operand.
                    self.assertEqual(
                        gradient_w[0],
                        func_backward(
                            input_shape if (gradient == "input") else input,
                            weight_shape if (gradient == "weight") else weight,
                            gradient_o,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                        ),
                    )

    def test_grad_conv1d_input(self):
        """Check ``F.grad.conv1d_input`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv1d,
            func_backward=F.grad.conv1d_input,
            dim=1,
            gradient="input",
        )

    def test_grad_conv1d_weight(self):
        """Check ``F.grad.conv1d_weight`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv1d,
            func_backward=F.grad.conv1d_weight,
            dim=1,
            gradient="weight",
        )

    def test_grad_conv2d_input(self):
        """Check ``F.grad.conv2d_input`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv2d,
            func_backward=F.grad.conv2d_input,
            dim=2,
            gradient="input",
        )

    def test_grad_conv2d_weight(self):
        """Check ``F.grad.conv2d_weight`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv2d,
            func_backward=F.grad.conv2d_weight,
            dim=2,
            gradient="weight",
        )

    def test_grad_conv3d_input(self):
        """Check ``F.grad.conv3d_input`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv3d,
            func_backward=F.grad.conv3d_input,
            dim=3,
            gradient="input",
        )

    def test_grad_conv3d_weight(self):
        """Check ``F.grad.conv3d_weight`` against autograd through ``run_grad_conv_test``."""
        self.run_grad_conv_test(
            func_forward=F.conv3d,
            func_backward=F.grad.conv3d_weight,
            dim=3,
            gradient="weight",
        )

    @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable")
    def test_nnpack_conv(self):
        """Compare ``torch._nnpack_spatial_convolution`` with ``conv2d``.

        For a sweep of kernel sizes, input sizes, batch sizes, strides and
        paddings, with and without a bias, both the forward result and the
        gradients w.r.t. input and weight must agree within ``atol=3e-4``.
        """
        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
            for batch, stride, padding, chan_in, chan_out in product(
                [1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]
            ):
                for has_bias in [True, False]:
                    input_shape = [batch, chan_in] + [inp_size] * 2
                    weight_shape = [chan_out, chan_in] + [kern] * 2

                    input = torch.randn(
                        input_shape, requires_grad=True, dtype=torch.float
                    )
                    weight = torch.randn(
                        weight_shape, requires_grad=True, dtype=torch.float
                    )
                    # Reset the bias on every iteration: otherwise the
                    # ``has_bias=False`` pass would silently reuse the tensor
                    # created by the preceding ``has_bias=True`` pass and the
                    # bias-free path would never be exercised.
                    bias = (
                        torch.randn([chan_out], requires_grad=True, dtype=torch.float)
                        if has_bias
                        else None
                    )
                    output = torch._nnpack_spatial_convolution(
                        input, weight, stride=stride, padding=padding, bias=bias
                    )
                    output_expected = torch.nn.functional.conv2d(
                        input, weight, stride=stride, padding=padding, bias=bias
                    )
                    self.assertEqual(output, output_expected, atol=3e-4, rtol=0)

                    gradient_o = torch.randn(output.shape, dtype=torch.float)

                    grads = torch.autograd.grad(output, [input, weight], gradient_o)
                    grads_expected = torch.autograd.grad(
                        output_expected, [input, weight], gradient_o
                    )
                    for gr, gr_expected in zip(grads, grads_expected):
                        self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0)

    def test_conv_padding_mode(self):
        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
            nn.Conv2d(3, 3, 3, padding_mode="xyz")

        with self.assertRaisesRegex(ValueError, "padding_mode must be one of"):
            nn.Conv2d(3, 3, 3, padding_mode=3)

        with self.assertRaisesRegex(ValueError, 'Only "zeros" '):
            nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect")

    def test_functional_grad_conv(self):
        """``torch.nn.grad`` helpers must match autograd for dilated 1-d, 2-d and 3-d convolutions."""
        cases = (
            # (forward op, input-gradient helper, weight-gradient helper)
            (F.conv1d, torch.nn.grad.conv1d_input, torch.nn.grad.conv1d_weight),
            (F.conv2d, torch.nn.grad.conv2d_input, torch.nn.grad.conv2d_weight),
            (F.conv3d, torch.nn.grad.conv3d_input, torch.nn.grad.conv3d_weight),
        )
        for ndim, (conv, grad_input_fn, grad_weight_fn) in enumerate(cases, start=1):
            # Single-channel input of side 5 and kernel of side 3 in ``ndim`` dims.
            x = torch.randn(1, 1, *([5] * ndim), requires_grad=True)
            kernel = torch.randn(1, 1, *([3] * ndim), requires_grad=True)
            y = conv(x, kernel, dilation=2)
            grad_y = torch.randn(y.shape)

            expected_dx, expected_dw = torch.autograd.grad(y, (x, kernel), grad_y)

            actual_dx = grad_input_fn(x.shape, kernel, grad_y, dilation=2)
            self.assertEqual(actual_dx, expected_dx)

            actual_dw = grad_weight_fn(x, kernel.shape, grad_y, dilation=2)
            self.assertEqual(actual_dw, expected_dw)

    def test_functional_grad_conv2d(self):
        """Compare ``torch.nn.grad`` conv2d helpers with autograd over a parameter sweep.

        Sweeps stride, kernel size, groups and dilation; padding is always
        ``kernel_size // 2``.
        """
        batch, in_ch, out_ch, side = 4, 8, 16, 32

        def check(stride, kernel_size, groups, dilation):
            conv_args = dict(
                stride=stride,
                padding=kernel_size // 2,
                dilation=dilation,
                groups=groups,
            )

            x = torch.empty(batch, in_ch, side, side).uniform_(-8.0, 8.0)
            x = x.requires_grad_(True)

            w = torch.empty(out_ch, in_ch // groups, kernel_size, kernel_size)
            w = w.uniform_(-4.0, 4.0).requires_grad_(True)

            y = F.conv2d(x, w, **conv_args)
            grad_y = torch.randn(y.shape)

            want_dx, want_dw = torch.autograd.grad(y, (x, w), grad_y)

            got_dx = torch.nn.grad.conv2d_input(x.shape, w, grad_y, **conv_args)
            self.assertEqual(got_dx, want_dx)

            got_dw = torch.nn.grad.conv2d_weight(x, w.shape, grad_y, **conv_args)
            self.assertEqual(got_dw, want_dw)

        # Argument order of each combination: stride, kernel size, groups, dilation.
        for combo in product([1, 2], [1, 3, 5], [1, 2, 4], [1, 2]):
            check(*combo)

    def test_permute_conv2d_issue_120211(self):
        """Regression test for issue 120211: grouped conv2d on a permuted image must not fail."""

        def reproducer(radius: int):
            channels_last_image = torch.rand(1, 1024, 1024, 3)
            image = channels_last_image.permute(0, 3, 1, 2)
            kernel_width = radius * 2 + 1
            kernel_x = torch.zeros([3, 1, 1, kernel_width], device=image.device)
            torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3])

        for radius in range(128):
            # This should not fail
            reproducer(radius=radius)

    def test_conv3d_issue_120406(self):
        # This should not fail
        F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3)

    def test_conv1d_issue_120547(self):
        weight = torch.ones([16, 1, 32])
        bias = torch.ones([16])
        stride, padding, dilation, groups = (1, 16, 1, 16)
        input = torch.rand((1, 1, 16))
        input = input.transpose(1, 2)
        # This should not fail
        F.conv1d(input, weight, bias, stride, padding, dilation, groups)


class TestConvolutionNNDeviceType(NNTestCase):
    def run_conv_double_back_test(
        self,
        kern,
        stride,
        padding,
        chan_in,
        chan_out,
        batch_size,
        inp_size,
        dilation,
        no_weight,
        groups=1,
        use_cuda=False,
        use_bias=True,
        dtype=torch.double,
    ):
        """Exercise double backward of ``F.conv2d`` for one configuration.

        Builds random square inputs and kernels from the given sizes and runs
        the convolution with cuDNN disabled. For ``torch.float`` it only
        reports whether the first-order input gradient is itself
        differentiable; for any other dtype it returns the result of
        ``gradgradcheck``.

        Args:
            kern, inp_size: side length of the kernel / input.
            stride, padding, dilation, groups: forwarded to ``F.conv2d``.
            chan_in, chan_out, batch_size: tensor sizes.
            no_weight: when true the weight does not require grad.
            use_cuda: run on ``cuda`` instead of ``cpu``.
            use_bias: include a bias that requires grad.
            dtype: dtype of every tensor created here.
        """
        device = torch.device("cuda" if use_cuda else "cpu")
        tensor_opts = dict(device=device, dtype=dtype)

        x = torch.randn(
            batch_size, chan_in, inp_size, inp_size, requires_grad=True, **tensor_opts
        )
        weight = torch.randn(
            chan_out,
            chan_in // groups,
            kern,
            kern,
            requires_grad=not no_weight,
            **tensor_opts,
        )
        bias = None
        if use_bias:
            bias = torch.randn(chan_out, requires_grad=True, **tensor_opts)

        def func(lx, lweight, lbias=None):
            # We disable cudnn during forward to avoid finite difference imprecision issues
            with cudnn.flags(enabled=False):
                return F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups)

        inputs = (x, weight, bias) if use_bias else (x, weight)

        dummy_out = func(*inputs)
        grad_y = torch.randn_like(
            dummy_out, device=device, dtype=dtype, requires_grad=True
        )

        # Issue #15353: test mkldnn double backward, don't run gradgradcheck due
        # to imprecision issues
        if dtype == torch.float:
            (first_order,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True)
            return first_order.requires_grad

        return gradgradcheck(func, inputs, (grad_y,))

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(
        *floating_and_complex_types_and(
            torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []
        )
    )
    @parametrize_test("dilation", [1, 2, 3])
    def test_Conv2d_deterministic_cudnn(self, device, dtype, dilation):
        """Two Conv2d modules with equal parameters must agree exactly under deterministic cuDNN.

        Outputs and the bias/weight gradients are compared with zero tolerance.
        """
        x = torch.randn(2, 3, 7, 7, device=device, dtype=dtype, requires_grad=True)
        with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
            first, second = (
                torch.nn.Conv2d(3, 3, 3, dilation=dilation).to(device, dtype)
                for _ in range(2)
            )
            # Give both modules the same parameters.
            second.bias.data.copy_(first.bias.data)
            second.weight.data.copy_(first.weight.data)

            out_first = first(x)
            out_second = second(x)
            self.assertEqual(out_first, out_second, atol=0.0, rtol=0)

            upstream = torch.randn(out_first.size(), device=device, dtype=dtype)
            out_first.backward(upstream)
            out_second.backward(upstream)
            for param_name in ("bias", "weight"):
                self.assertEqual(
                    getattr(first, param_name).grad.data,
                    getattr(second, param_name).grad.data,
                    atol=0.0,
                    rtol=0,
                )

    @onlyCUDA
    @dtypes(
        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
    )
    def test_Conv2d_large_workspace(self, device, dtype):
        """Run forward and backward on large inputs with cuDNN benchmark mode off and on."""
        # These sizes require huge cuDNN workspaces. Make sure we choose a
        # reasonable algorithm that does not run out of memory
        sizes = [
            (1, 256, 109, 175),
            (1, 256, 80, 128),
            (1, 256, 120, 192),
        ]

        def run_test(benchmark):
            with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark):
                layer = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1)
                layer = layer.to(device, dtype)
                for shape in sizes:
                    sample = torch.randn(shape, device=device, dtype=dtype)
                    leaf = sample.detach().clone().requires_grad_()
                    result = layer(leaf)
                    result.backward(torch.ones_like(result))

        for benchmark in (False, True):
            run_test(benchmark=benchmark)

    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_ConvTranspose2d_large_output_padding(self, device, dtype):
        """Smoke-test a stack of three stride-2 ConvTranspose2d layers with ``output_padding=1``."""
        channel_plan = [(128, 64), (64, 32), (32, 3)]
        stages = [
            torch.nn.ConvTranspose2d(
                c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
            ).to(device=device, dtype=dtype)
            for c_in, c_out in channel_plan
        ]

        activation = torch.rand(
            1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True
        )
        for stage in stages:
            activation = stage(activation)

        activation.backward(torch.randn_like(activation))
        torch.cuda.synchronize()

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    # Very similar to test_Conv2d_naive_groups but with special care to handle
    # the number of groups == number of input channels
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    @tf32_on_and_off(0.01)
    def test_Conv2d_depthwise_naive_groups(self, device, dtype):
        """Check a depthwise (``groups == in_channels``) Conv2d against per-channel convolutions.

        For depth multipliers 1 and 2, the grouped module's output and its
        input, bias and weight gradients must match the concatenation of two
        single-input-channel Conv2d modules carrying the same parameters.
        """
        for depth_multiplier in [1, 2]:
            m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            # Create the input on the parametrized device, like every other
            # tensor in this test, rather than on a hard-coded "cuda".
            i = (
                torch.randn(2, 2, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype)
                / 2
            )
            output.backward(grad_output)

            # Number of output channels produced by each group.
            offset = 1 * depth_multiplier

            m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())

            self.assertEqual(
                output,
                torch.cat([output1, output2], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.half)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    @tf32_on_and_off(0.01)
    def test_Conv3d_depthwise_naive_groups(self, device, dtype):
        """Check a depthwise (``groups == in_channels``) Conv3d against per-channel convolutions.

        For depth multipliers 1 and 2, the grouped module's output and its
        input, bias and weight gradients must match the concatenation of two
        single-input-channel Conv3d modules carrying the same parameters.
        """
        for depth_multiplier in [1, 2]:
            m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(
                device, dtype
            )
            # Create the input on the parametrized device, like every other
            # tensor in this test, rather than on a hard-coded "cuda".
            i = (
                torch.randn(2, 2, 6, 6, 6, device=device, dtype=dtype)
                .div_(2)
                .requires_grad_()
            )
            output = m(i)
            grad_output = (
                torch.randn(
                    2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype
                )
                / 2
            )
            output.backward(grad_output)

            # Number of output channels produced by each group.
            offset = 1 * depth_multiplier

            m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m1.weight.data = m.weight.data[:offset].clone()
            m1.bias.data = m.bias.data[:offset].clone()
            i1 = i.detach()[:, :1].clone().requires_grad_()
            output1 = m1(i1)
            output1.backward(grad_output[:, :offset].contiguous())

            m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype)
            m2.weight.data.copy_(m.weight.data[offset:])
            m2.bias.data.copy_(m.bias.data[offset:])
            i2 = i.detach()[:, 1:].clone().requires_grad_()
            output2 = m2(i2)
            output2.backward(grad_output[:, offset:].contiguous())
            # Query the capability of the device under test, not device 0.
            is_cuda_sm86 = device.startswith(
                "cuda"
            ) and torch.cuda.get_device_capability(device) == (8, 6)
            # float32 on SM 8.6 gets looser tolerances for the output and
            # weight-gradient comparisons.
            atol, rtol = (
                (3e-4, 3e-2)
                if dtype == torch.float32 and is_cuda_sm86
                else (dtype2prec_DONTUSE[dtype], 0)
            )

            self.assertEqual(
                output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol
            )
            self.assertEqual(
                i.grad.data,
                torch.cat([i1.grad.data, i2.grad.data], 1),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.bias.grad.data,
                torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
                atol=dtype2prec_DONTUSE[dtype],
                rtol=0,
            )
            self.assertEqual(
                m.weight.grad.data,
                torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
                atol=atol,
                rtol=rtol,
            )

    @onlyCUDA
    @dtypes(
        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
    )
    def test_noncontig_conv_grad(self, device, dtype):
        """Conv2d backward must give the same input gradient for a non-contiguous and a contiguous grad."""
        # FIXME: remove after adding non-contiguous grad tests for all modules
        module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype)
        input = torch.randn(
            2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True
        )
        output = module(input)

        # Slicing dim 1 of a 5-d tensor yields a non-contiguous 4-d gradient.
        grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1]
        # Use a unittest assertion so the precondition is still checked when
        # Python runs with -O (which strips bare ``assert`` statements).
        self.assertFalse(grad.is_contiguous())
        output.backward(grad, retain_graph=True)
        self.assertIsNotNone(input.grad)
        result = input.grad.data.clone()
        # Reset the accumulated gradient before the second backward pass.
        input.grad.data.zero_()

        output.backward(grad.contiguous())
        self.assertEqual(
            result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0
        )

    @onlyCUDA
    @dtypes(torch.double)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    def test_conv_double_backward(self, device, dtype):
        # Double backward only runs with DoubleTensor due to precision reason
        batch_size = 1
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [1], [2], [2], [3], dilations
            ):
                no_weight = stride == 2
                result = self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                    use_cuda=True,
                    dtype=dtype,
                )
                self.assertTrue(
                    result,
                    "Conv double backward test failed with parameters:"
                    + "\nkern: "
                    + str(kern)
                    + "\nstride: "
                    + str(stride)
                    + "\npadding: "
                    + str(padding)
                    + "\nchan_in: "
                    + str(chan_in)
                    + "\nchan_out: "
                    + str(chan_out)
                    + "\nbatch_size: "
                    + str(batch_size)
                    + "\ninp_size: "
                    + str(inp_size)
                    + "\ndilation: "
                    + str(dilation),
                )

    def test_conv_double_backward_no_bias(self):
        kern = 3
        stride = 2
        chan_in, chan_out = 2, 4
        batch_size = 2
        inp_size = 5
        padding = 1
        dilation = 1
        no_weight = False
        use_bias = True
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in,
            chan_out,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            use_bias=use_bias,
        )
        self.assertTrue(
            result,
            "Conv double backward test failed with parameters:"
            + "\nkern: "
            + str(kern)
            + "\nstride: "
            + str(stride)
            + "\npadding: "
            + str(padding)
            + "\nchan_in: "
            + str(chan_in)
            + "\nchan_out: "
            + str(chan_out)
            + "\nbatch_size: "
            + str(batch_size)
            + "\ninp_size: "
            + str(inp_size)
            + "\ndilation: "
            + str(dilation),
        )

    def test_conv_double_backward_groups(self):
        kern = 3
        stride = 1
        padding = 2
        chan_in, chan_out = 2, 4
        batch_size = 2
        inp_size = 6
        dilation = 1
        no_weight = False
        groups = 2
        result = self.run_conv_double_back_test(
            kern,
            stride,
            padding,
            chan_in * groups,
            chan_out * groups,
            batch_size,
            inp_size,
            dilation,
            no_weight,
            groups=groups,
        )
        self.assertTrue(
            result,
            "Conv double backward test failed with parameters:"
            + "\nkern: "
            + str(kern)
            + "\nstride: "
            + str(stride)
            + "\npadding: "
            + str(padding)
            + "\nchan_in: "
            + str(chan_in)
            + "\nchan_out: "
            + str(chan_out)
            + "\nbatch_size: "
            + str(batch_size)
            + "\ninp_size: "
            + str(inp_size)
            + "\ndilation: "
            + str(dilation)
            + "\ngroups: "
            + str(groups),
        )

    def test_conv_double_backward_stride(self):
        batch_size = 2

        # Cannot provide ggW when stride is > 1
        for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]:
            for stride, padding, chan_in, chan_out, dilation in product(
                [2], [0, 1], [1], [2], dilations
            ):
                no_weight = False
                self.run_conv_double_back_test(
                    kern,
                    stride,
                    padding,
                    chan_in,
                    chan_out,
                    batch_size,
                    inp_size,
                    dilation,
                    no_weight,
                )

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    def test_conv1d_same_padding(self, device, dtype):
        # Test padding='same' outputs the correct shape
        test_args = [
            # in_size
            range(50, 55),
            # kernel_size
            [1, 2, 3, 8],
            # dilation
            range(1, 4),
            # stride
            [1],
        ]
        for in_size, k_size, dilation, stride in itertools.product(*test_args):
            x = torch.rand(1, 1, in_size, device=device, dtype=dtype)
            y = torch.rand(1, 1, k_size, device=device, dtype=dtype)
            z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride)
            self.assertEqual(z.size(2), int(math.ceil(in_size / stride)))

        # Compare F.conv1d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 3, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=1)
        actual = F.conv1d(x, y, padding="same")
        self.assertEqual(expect, actual)

        # With dilation
        x = torch.rand(1, 1, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y, padding=3, dilation=2)
        actual = F.conv1d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        # Dilation with asymmetric padding
        expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:]
        actual = F.conv1d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @tf32_on_and_off(0.005)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @dtypes(torch.float, torch.cfloat)
    def test_conv2d_same_padding(self, device, dtype):
        # Compare F.conv2d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
        actual = F.conv2d(x, y, padding="same")
        self.assertEqual(expect, actual)

        # With dilation
        y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
        actual = F.conv2d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
        actual = F.conv2d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_same_padding(self, device, dtype):
        if dtype is torch.cfloat:
            rtol, atol = 2e-6, 2e-6
        else:
            rtol, atol = None, None
        # Compare F.conv3d padding='same' output against manual padding
        # Without strides/dilation
        x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :]
        actual = F.conv3d(x, y, padding="same")
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # With dilation
        expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2)
        actual = F.conv3d(x, y, padding="same", dilation=2)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

        # Dilation with asymmetric padding
        y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype)
        expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:]
        actual = F.conv3d(x, y, padding="same", dilation=3)
        self.assertEqual(expect, actual, rtol=rtol, atol=atol)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_valid_padding(self, device, dtype):
        # Test F.conv1d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 4, device=device, dtype=dtype)
        expect = F.conv1d(x, y)
        actual = F.conv1d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv2d_valid_padding(self, device, dtype):
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype)
        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    def test_conv3d_valid_padding(self, device, dtype):
        # Test F.conv3d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device)
        expect = F.conv3d(x, y)
        actual = F.conv3d(x, y, padding="valid")
        self.assertEqual(expect, actual)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_same_padding_backward(self, device, dtype):
        # Test F.conv1d gradients work with padding='same'
        x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        z = F.conv1d(x, y, padding=3, dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        # Asymmetric padding
        z = F.conv1d(x, y, padding=2)[..., 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv1d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @tf32_on_and_off(0.001)
    def test_conv2d_same_padding_backward(self, device, dtype):
        # Test F.conv2d gradients work with padding='same'
        x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True)

        # Symmetric padding
        z = F.conv2d(x, y, padding=(3, 4), dilation=2)
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same", dilation=2)
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        x.grad, y.grad = None, None

        # Asymmetric padding
        y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True)
        z = F.conv2d(x, y, padding=2)[..., 1:, 1:]
        z.sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        z = F.conv2d(x, y, padding="same")
        z.sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

    @dtypes(torch.double, torch.cdouble)
    @dtypesIfMPS(
        torch.float, torch.cfloat
    )  # Double, complex double not supported on MPS
    @expectedFailureMPS  # https://github.com/pytorch/pytorch/issues/107214
    def test_conv3d_same_padding_backward(self, device, dtype):
        # Gradients through F.conv3d with padding='same' must equal the
        # gradients obtained with explicit padding, and must also pass
        # gradcheck / gradgradcheck.
        device_type = torch.device(device).type
        # Forward-mode AD checking is switched off on XLA devices
        check_forward_ad = device_type != "xla"
        # https://github.com/pytorch/pytorch/issues/70702
        run_gradgradcheck = device_type != "cuda"

        def take_grads(*leaves):
            # Return the accumulated .grad of every leaf, then clear it
            grads = tuple(leaf.grad for leaf in leaves)
            for leaf in leaves:
                leaf.grad = None
            return grads

        def conv_same_dilated(inp, weight):
            return F.conv3d(inp, weight, padding="same", dilation=2)

        def conv_same(inp, weight):
            return F.conv3d(inp, weight, padding="same")

        x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True)

        # Symmetric padding
        F.conv3d(x, y, padding=(0, 1, 4), dilation=2).sum().abs().backward()
        gx_expect, gy_expect = take_grads(x, y)

        conv_same_dilated(x, y).sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)
        take_grads(x, y)

        gradcheck(
            conv_same_dilated,
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if run_gradgradcheck:
            gradgradcheck(conv_same_dilated, (x, y), check_fwd_over_rev=True)

        # Asymmetric padding: a fresh 1x4x4 kernel; the reference over-pads and
        # drops the first slice of the last two axes
        y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv3d(x, y, padding=2)[..., 1:, 1:].sum().abs().backward()
        gx_expect, gy_expect = take_grads(x, y)

        conv_same(x, y).sum().abs().backward()
        self.assertEqual(gx_expect, x.grad)
        self.assertEqual(gy_expect, y.grad)

        gradcheck(
            conv_same,
            (x, y),
            check_forward_ad=check_forward_ad,
            nondet_tol=1e-5,
        )
        if run_gradgradcheck:
            gradgradcheck(conv_same, (x, y), check_fwd_over_rev=True)

    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv1d_valid_padding_backward(self, device, dtype):
        # Test F.conv1d gradients work with padding='valid'
        x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True)
        F.conv1d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv1d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @parametrize_test("mode", ("valid", "same"))
    def test_conv1d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv1d against scipy.signal.convolve for an even- and
        # an odd-length kernel in the given padding `mode`
        signal = make_tensor((1, 10), device=device, dtype=dtype)
        feat_dim = signal.shape[1]
        kernels = (
            make_tensor((1, 1, 4), device=device, dtype=dtype),  # even length
            make_tensor((1, 1, 5), device=device, dtype=dtype),  # odd length
        )

        def _compare(inp, kernel):
            # SciPy expects two 1-D inputs.
            reference = scipy.signal.convolve(
                inp.view(-1).cpu().numpy(), kernel.view(-1).cpu().numpy(), mode=mode
            )

            if mode == "same":
                # `same` padding in PyTorch conv1d is different from SciPy,
                # so pad by hand and run the convolution with no padding
                # argument at all
                half = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(inp, (half, half))
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve
            flipped = torch.flip(kernel, (2,))
            result = torch.nn.functional.conv1d(inp, flipped, **conv_kwargs).squeeze(0)
            if mode == "same":
                # Keep only as many elements as the unpadded input had
                result = result[:feat_dim]

            self.assertEqual(result, reference, atol=2e-5, rtol=2e-5)

        # Global dtype for this test suite is torch.double, which changes
        # type promotion: conv1d would output `complex128` for `complex64`
        # input. Run under a float default instead.
        with set_default_dtype(torch.float):
            for kernel in kernels:
                _compare(signal, kernel)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    @parametrize_test("mode", ("valid", "same"))
    def test_conv2d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv2d against scipy.signal.convolve2d for an even- and
        # an odd-sized kernel in the given padding `mode`
        image = make_tensor((1, 5, 10), device=device, dtype=dtype)
        kernels = (
            make_tensor((1, 1, 2, 4), device=device, dtype=dtype),  # even sizes
            make_tensor((1, 1, 3, 5), device=device, dtype=dtype),  # odd sizes
        )

        def _compare(inp, kernel):
            # SciPy expects two 2-D inputs.
            reference = scipy.signal.convolve2d(
                inp.squeeze(0).cpu().numpy(),
                kernel.squeeze(0).squeeze(0).cpu().numpy(),
                mode=mode,
            )

            if mode == "same":
                # `same` padding in PyTorch conv2d is different from SciPy,
                # so pad by hand and run the convolution with no padding
                # argument at all
                pad_w = kernel.shape[3] // 2
                pad_h = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(inp, (pad_w, pad_w, pad_h, pad_h))
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve2d
            flipped = torch.flip(kernel, (2, 3))
            result = torch.nn.functional.conv2d(inp, flipped, **conv_kwargs).squeeze(0)
            if mode == "same":
                # Crop back to the 5x10 extent of the unpadded input
                result = result[:5, :10]

            self.assertEqual(result, reference, rtol=2e-5, atol=5e-6)

        # Global dtype for this test suite is torch.double, which changes
        # type promotion so that the convolution would output `complex128`
        # for `complex64` input. Run under a float default instead.
        with set_default_dtype(torch.float):
            for kernel in kernels:
                _compare(image, kernel)

    @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @skipMPS  # Results in CI are inconsistent, forced to skip
    @dtypes(torch.float, torch.cfloat)
    @parametrize_test("mode", ("valid", "same"))
    def test_conv3d_vs_scipy(self, device, dtype, mode):
        # Cross-check F.conv3d against scipy.signal.convolve for two kernel
        # shapes in the given padding `mode`
        volume = make_tensor((1, 5, 5, 10), device=device, dtype=dtype)
        kernels = (
            make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype),
            make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype),
        )

        def _compare(inp, kernel):
            # SciPy expects two 3-D inputs.
            reference = scipy.signal.convolve(
                inp.squeeze(0).cpu().numpy(),
                kernel.squeeze(0).squeeze(0).cpu().numpy(),
                mode=mode,
            )

            if mode == "same":
                # `same` padding in PyTorch conv3d is different from SciPy,
                # so pad by hand and run the convolution with no padding
                # argument at all. F.pad takes the last axis first.
                pad_w = kernel.shape[4] // 2
                pad_h = kernel.shape[3] // 2
                pad_d = kernel.shape[2] // 2
                inp = torch.nn.functional.pad(
                    inp, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)
                )
                conv_kwargs = {}
            else:
                conv_kwargs = {"padding": mode}

            # second input is flipped in SciPy's convolve
            flipped = torch.flip(kernel, (2, 3, 4))
            result = torch.nn.functional.conv3d(inp, flipped, **conv_kwargs).squeeze(0)
            if mode == "same":
                # Crop back to the 5x5x10 extent of the unpadded input
                result = result[:5, :5, :10]

            # Much looser tolerances for float / complex64 when TF32 is supported
            if torch.cuda.is_tf32_supported() and dtype in (
                torch.float,
                torch.complex64,
            ):
                tol = {"atol": 0.05, "rtol": 0.05}
            else:
                tol = {"rtol": 2e-5, "atol": 5e-6}
            self.assertEqual(result, reference, **tol)

        # Global dtype for this test suite is torch.double, which changes
        # type promotion so that the convolution would output `complex128`
        # for `complex64` input. Run under a float default instead.
        with set_default_dtype(torch.float):
            for kernel in kernels:
                _compare(volume, kernel)

    @dtypes(torch.float, torch.complex64)
    @dtypesIfMPS(
        *([torch.float] if MACOS_VERSION < 14.0 else [torch.float, torch.cfloat])
    )  # Complex not supported on MacOS13
    def test_conv2d_valid_padding_backward(self, device, dtype):
        # Test F.conv2d gradients work with padding='valid'
        x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True)
        F.conv2d(x, y, padding=0).sum().abs().backward()
        gx_expect, gy_expect = x.grad, y.grad
        x.grad, y.grad = None, None

        F.conv2d(x, y, padding="valid").sum().abs().backward()
        gx_actual, gy_actual = x.grad, y.grad
        self.assertEqual(gx_expect, gx_actual)
        self.assertEqual(gy_expect, gy_actual)

    @dtypes(torch.double, torch.cdouble)
    @dtypesIfMPS(
        torch.float, torch.cfloat
    )  # Double, complex double not supported on MPS
    @expectedFailureMPS  # https://github.com/pytorch/pytorch/issues/107214
    def test_conv3d_valid_padding_backward(self, device, dtype):
        # Gradients through F.conv3d must agree between padding=0 and
        # padding='valid', and the 'valid' path must pass gradcheck and
        # gradgradcheck.
        # Forward-mode AD checking is switched off on XLA devices
        check_forward_ad = torch.device(device).type != "xla"

        def conv_valid(inp, weight):
            return F.conv3d(inp, weight, padding="valid")

        x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True)
        y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True)

        # Reference gradients from the explicitly unpadded convolution
        F.conv3d(x, y, padding=0).sum().abs().backward()
        expected = (x.grad, y.grad)
        # Clear so the second backward does not accumulate onto the first
        x.grad = y.grad = None

        conv_valid(x, y).sum().abs().backward()
        for want, got in zip(expected, (x.grad, y.grad)):
            self.assertEqual(want, got)

        gradcheck(conv_valid, (x, y), check_forward_ad=check_forward_ad)
        gradgradcheck(conv_valid, (x, y), check_fwd_over_rev=check_forward_ad)

    @parametrize_test(
        arg_str="N",
        arg_values=[
            subtest(arg_values=(2), name="ConvTranspose2d"),
            subtest(arg_values=(3), name="ConvTranspose3d"),
        ],
    )
    def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
        # For inputs with no batch dim, verify output is the correct shape when output_size is set.
        # See https://github.com/pytorch/pytorch/issues/75889
        inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device)
        output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200)
        ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d")
        m = ConvTransposeNd(
            1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device
        )
        output = m(inp, output_size=output_size)
        self.assertEqual(output.shape, output_size)

    @skipMeta
    @parametrize_test(
        "input_shape,transposed,dilated,groups,layout,backend_expected",
        [
            # === slow ===
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Slow2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow1d",
            ),
            subtest(
                (
                    (2, 6, 7),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowTranspose2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow1d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowDilated2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow1d_dilated",
            ),
            subtest(
                (
                    (2, 6, 7),
                    True,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowTranspose2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow1d_dilated_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Slow2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowTranspose2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow2d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowDilated2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow2d_dilated",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    True,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowTranspose2d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow2d_dilated_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Slow3d,
                ),
                decorators=[onlyCPU, disableMkldnn],
                name="slow3d_cpu",
            ),
            # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowDilated3d,
                ),
                decorators=[onlyCUDA, disablecuDNN],
                name="slow3d_cuda",
            ),
            # FIXME: RuntimeError: CUDA out of memory.
            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.SlowDilated3d,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN],
                name="slow3d_dilated",
            ),
            # FIXME: RuntimeError: CUDA out of memory.
            # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d),
            #         decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'),
            subtest(
                (
                    (0, 6, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch1d",
            ),
            subtest(
                (
                    (2, 0, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_channel1d",
            ),
            subtest(
                (
                    (0, 0, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch_channel1d",
            ),
            subtest(
                (
                    (0, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch2d",
            ),
            subtest(
                (
                    (2, 0, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_channel2d",
            ),
            subtest(
                (
                    (0, 0, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch_channel2d",
            ),
            subtest(
                (
                    (0, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch3d",
            ),
            subtest(
                (
                    (2, 0, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_channel3d",
            ),
            subtest(
                (
                    (0, 0, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Empty,
                ),
                decorators=[onlyNativeDeviceTypes, disableMkldnn],
                name="empty_batch_channel3d",
            ),
            # === cuda ===
            # Note that disablecuDNN disables miopen as well.
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.CudaDepthwise2d,
                ),
                decorators=[onlyCUDA, disablecuDNN],
                name="cuda_depthwise1d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.CudaDepthwise2d,
                ),
                decorators=[onlyCUDA, disablecuDNN],
                name="cuda_depthwise2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.CudaDepthwise3d,
                ),
                decorators=[onlyCUDA, disablecuDNN],
                name="cuda_depthwise3d",
            ),
            # === cudnn ===
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Cudnn,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
                name="cudnn1d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Cudnn,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
                name="cudnn2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Cudnn,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
                name="cudnn3d",
            ),
            subtest(
                (
                    (2, 6, 7),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.CudnnTranspose,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
                name="cudnn1d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.CudnnTranspose,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen],
                name="cudnn2d_transposed",
            ),
            # FIXME: RuntimeError: CUDA out of memory.
            # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose),
            #         decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'),
            # === miopen ===
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Miopen,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen1d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Miopen,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Miopen,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen3d",
            ),
            subtest(
                (
                    (2, 6, 7),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.MiopenTranspose,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen1d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.MiopenTranspose,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen2d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    True,
                    False,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.MiopenTranspose,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen3d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.MiopenDepthwise,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen_depthwise1d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.MiopenDepthwise,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen_depthwise2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    6,
                    torch.strided,
                    torch._C._ConvBackend.MiopenDepthwise,
                ),
                decorators=[onlyCUDA, skipCUDAIfNoMiopen],
                name="miopen_depthwise3d",
            ),
            # === mkldnn ===
            subtest(
                (
                    (2, 6, 7),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn1d",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn2d",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn3d",
            ),
            # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775.
            subtest(
                (
                    (2, 6, 7),
                    True,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
                name="mkldnn1d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    True,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
                name="mkldnn2d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    True,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure],
                name="mkldnn3d_transposed",
            ),
            subtest(
                (
                    (2, 6, 7),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn1d_cpu_input",
            ),
            subtest(
                (
                    (2, 6, 7, 8),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn2d_cpu_input",
            ),
            subtest(
                (
                    (2, 6, 7, 8, 9),
                    False,
                    True,
                    3,
                    torch.strided,
                    torch._C._ConvBackend.Mkldnn,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn3d_cpu_input",
            ),
            subtest(
                (
                    (0, 6, 7),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch1d",
            ),
            subtest(
                (
                    (2, 0, 7),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_channel1d",
            ),
            subtest(
                (
                    (0, 0, 7),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch_channel1d",
            ),
            subtest(
                (
                    (0, 6, 7, 8),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch2d",
            ),
            subtest(
                (
                    (2, 0, 7, 8),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_channel2d",
            ),
            subtest(
                (
                    (0, 0, 7, 8),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch_channel2d",
            ),
            subtest(
                (
                    (0, 6, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch3d",
            ),
            subtest(
                (
                    (2, 0, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_channel3d",
            ),
            subtest(
                (
                    (0, 0, 7, 8, 9),
                    False,
                    False,
                    3,
                    torch._mkldnn,
                    torch._C._ConvBackend.MkldnnEmpty,
                ),
                decorators=[onlyCPU, skipCPUIfNoMkldnn],
                name="mkldnn_empty_batch_channel3d",
            ),
            # Note: Tests for mobile backends are not currently supported. This comprises
            # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these
            # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1.
        ],
    )
    # Test with both bias and no bias.
    @parametrize_test("has_bias", [False, True])
    # Test with both stride=1 and stride>1 cases.
    @parametrize_test("strided", [False, True])
    # Test with both contiguous and non-contiguous inputs.
    @parametrize_test("contiguous", [False, True])
    @expectedFailureMPS  # No double support
    def test_conv_backend(
        self,
        device,
        input_shape,
        has_bias,
        strided,
        contiguous,
        transposed,
        dilated,
        groups,
        layout,
        backend_expected,
    ):
        # For one parametrized convolution configuration, checks that
        # torch._C._select_conv_backend returns `backend_expected`, then runs
        # aten.convolution forward + backward, a forward-AD /
        # forward-over-reverse smoke test in float32, and finally gradcheck and
        # gradgradcheck in float64.
        #
        # Arguments used below:
        #   input_shape:      shape of the input; index 1 is read as C_in and
        #                     the number of spatial dims is len(input_shape) - 2
        #   has_bias:         whether a bias tensor of size C_out is passed
        #   strided:          stride 2 in every spatial dim if True, else 1
        #   contiguous:       if False, x / weight / bias (and grad_output) are
        #                     passed through _make_noncontiguous
        #   transposed:       forwarded to the convolution; also swaps the
        #                     channel layout of the weight
        #   dilated:          dilation 2 in every spatial dim if True, else 1
        #   groups:           number of convolution groups
        #   layout:           torch.strided or torch._mkldnn (input layout)
        #   backend_expected: the torch._C._ConvBackend value expected

        # Build up inputs.
        dtype = torch.float32
        # C_out and kernel_size are fixed (12 and 3) for every configuration.
        C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3
        x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True)
        # Weight is (C_in, C_out // groups, k, ...) for transposed convolution
        # and (C_out, C_in // groups, k, ...) otherwise.
        weight = torch.randn(
            C_in if transposed else C_out,
            C_out // groups if transposed else C_in // groups,
            *[kernel_size for _ in range(dim)],
            device=device,
            dtype=dtype,
            requires_grad=True,
        )
        bias = (
            torch.randn(C_out, device=device, dtype=dtype, requires_grad=True)
            if has_bias
            else None
        )

        def _make_noncontiguous(inp):
            # Duplicates every element along the last dim and then keeps every
            # second one, which yields a strided view holding the original
            # values. The result is detached and given the original
            # requires_grad flag. None is passed through (absent bias).
            if inp is None:
                return None
            old_requires_grad = inp.requires_grad
            inp = torch.repeat_interleave(inp, 2, dim=-1)
            inp = inp[..., ::2].detach().requires_grad_(old_requires_grad)
            return inp

        if not contiguous:
            x = _make_noncontiguous(x)
            weight = _make_noncontiguous(weight)
            bias = _make_noncontiguous(bias)

        if layout is torch._mkldnn:
            x = x.to_mkldnn()
            # Note that weight and bias are not supported as mkldnn tensors during training.

        stride = (2,) * dim if strided else (1,) * dim
        padding = (0,) * dim
        dilation = (2,) * dim if dilated else (1,) * dim
        output_padding = (0,) * dim
        # The same positional list is passed to both _select_conv_backend and
        # aten.convolution below.
        inputs = [
            x,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ]

        # Ensure correct backend is selected.
        backend_actual = torch._C._select_conv_backend(*inputs)
        self.assertEqual(backend_actual, backend_expected)

        # Ensure backward call succeeds.
        convolution = torch.ops.aten.convolution
        output = convolution(*inputs)
        grad_output = torch.randn(output.shape, device=device, dtype=dtype)
        # grad_output gets the same contiguity / layout treatment as the input.
        if not contiguous:
            grad_output = _make_noncontiguous(grad_output)
        if layout is torch._mkldnn:
            grad_output = grad_output.to_mkldnn()
        output.backward(grad_output)

        # mkldnn doesn't support gradcheck :(
        # (this early return also skips the forward-AD smoke test below)
        if layout is torch._mkldnn:
            return

        if backend_actual != torch._C._ConvBackend.Empty:  # FIXME: forward AD fails
            # Forward AD and forward-over-reverse AD smoke test in float32
            # TODO: remove this if we introduce per-op gradient tests for float32
            with fwAD.dual_level():
                # Every tensor argument becomes a dual tensor with a random
                # tangent; non-tensor arguments (including a None bias) are
                # passed through unchanged.
                dual_inputs = [
                    (
                        fwAD.make_dual(i, torch.rand_like(i))
                        if isinstance(i, torch.Tensor)
                        else i
                    )
                    for i in inputs
                ]
                # Forward AD
                output = convolution(*dual_inputs)
                # Forward over reverse AD
                grad_output_d = fwAD.make_dual(
                    torch.rand_like(output), torch.rand_like(output)
                )
                if has_bias:
                    torch.autograd.grad(output, [x, weight, bias], grad_output_d)
                else:
                    torch.autograd.grad(output, [x, weight], grad_output_d)

        # Convert to float64 for gradcheck.
        # Each tensor is detached and re-marked as requiring grad so it is a
        # fresh leaf for the numerical checks.
        x = x.to(torch.float64).detach().requires_grad_(True)
        weight = weight.to(torch.float64).detach().requires_grad_(True)
        if bias is not None:
            bias = bias.to(torch.float64).detach().requires_grad_(True)
        inputs = [
            x,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ]

        # Set some backend-specific validation settings.
        gradcheck_nondet_tol = 0.0
        # Note: this is keyed on cuDNN availability, not on the backend that
        # was actually selected above.
        if torch.backends.cudnn.is_available():
            # cuDNN introduces non-determinism
            gradcheck_nondet_tol = GRADCHECK_NONDET_TOL

        self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))

        # double backward doesn't support bias gradients
        if bias is not None:
            bias.requires_grad_(False)
        self.assertTrue(
            gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)
        )

    @onlyCPU
    def test_conv_contiguous_for_oneDNN(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/80837:
        # Conv2d applied to a view obtained by transposing and indexing a 5-D
        # tensor must give the same result with MKLDNN enabled and disabled.
        conv_kwargs = {
            "kernel_size": (5, 2),
            "stride": (2, 1),
            "padding": (0, 1),
            "dilation": (1, 1),
            "groups": 1,
            "bias": True,
            "padding_mode": "zeros",
        }
        for dtype in (torch.float, torch.bfloat16, torch.half):
            module = nn.Conv2d(1, 128, **conv_kwargs).to(dtype=dtype)

            # (1, 2, 321, 201, 1) -> swap dims 1 and 4 -> take index 0 of the
            # last dim, leaving a (1, 1, 321, 201) view.
            base = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
            view = torch.transpose(base, 1, 4)[..., 0]

            if not torch.backends.mkldnn.is_available():
                continue
            with_mkldnn = module(view)
            # Disable MKLDNN explicitly for the reference run.
            with torch.backends.mkldnn.flags(enabled=False):
                without_mkldnn = module(view)
                self.assertEqual(with_mkldnn, without_mkldnn)

    @onlyCPU
    def test_conv_ic1_channels_last_for_oneDNN(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/82060:
        # a channels_last Conv2d with a single input channel must give the same
        # result with MKLDNN enabled and disabled. Per that issue, N > 1 is
        # what routes the call into the OneDNN path, hence the batch of 2.
        for dtype in (torch.float, torch.bfloat16, torch.half):
            module = torch.nn.Conv2d(
                1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False
            )
            module = module.to(memory_format=torch.channels_last)
            module = module.to(dtype=dtype)
            inp = torch.rand(2, 1, 100, 100).to(dtype=dtype)

            if not torch.backends.mkldnn.is_available():
                continue
            with_mkldnn = module(inp)
            # Disable MKLDNN explicitly for the reference run.
            with torch.backends.mkldnn.flags(enabled=False):
                without_mkldnn = module(inp)
                self.assertEqual(with_mkldnn, without_mkldnn)

    @dtypes(torch.float, torch.cfloat)
    def test_conv_empty_channel(self, device, dtype):
        # Conv1d/2d/3d modules built with zero input channels are run through
        # _test_module_empty_input on an input whose channel dim is empty.
        # An input with one channel (and an empty spatial dim) must instead
        # raise a RuntimeError matching "Given groups=1, weight".
        in_channels = 0
        cases = (
            # (module class, out_channels, kernel_size, accepted, rejected)
            (torch.nn.Conv1d, 8, 2, (2, 0, 15), (2, 1, 0)),
            (torch.nn.Conv2d, 33, 3, (2, 0, 50, 100), (2, 1, 40, 0)),
            (torch.nn.Conv3d, 33, 3, (2, 0, 50, 20, 40), (2, 1, 50, 0, 40)),
        )
        for conv_cls, out_channels, kernel_size, ok_shape, bad_shape in cases:
            mod = conv_cls(
                in_channels, out_channels, kernel_size, stride=2, dtype=dtype
            ).to(device)
            inp = torch.randn(*ok_shape, device=device, dtype=dtype)
            _test_module_empty_input(self, mod, inp, check_size=False)

            with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
                inp = torch.randn(*bad_shape, device=device, dtype=dtype)
                mod(inp)

    def test_group_conv_empty(self, device):
        # A grouped Conv2d is run through _test_module_empty_input on a
        # zero-sized batch; on CUDA with cuDNN the check is repeated with
        # cuDNN disabled.
        conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=1, groups=4)
        conv = conv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, conv, empty_batch, check_size=False)

        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, conv, empty_batch, check_size=False)

    def test_group_convTranspose_empty(self, device):
        # A grouped ConvTranspose2d is run through _test_module_empty_input on
        # a zero-sized batch; on CUDA with cuDNN the check is repeated with
        # cuDNN disabled.
        deconv = torch.nn.ConvTranspose2d(
            4, 4, kernel_size=3, stride=2, padding=1, groups=4
        )
        deconv = deconv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, deconv, empty_batch, check_size=False)

        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, deconv, empty_batch, check_size=False)

    def test_convTranspose_empty(self, device):
        # An ungrouped ConvTranspose2d is run through _test_module_empty_input
        # on a zero-sized batch; on CUDA with cuDNN the check is repeated with
        # cuDNN disabled.
        deconv = torch.nn.ConvTranspose2d(4, 4, kernel_size=3, stride=2, padding=1)
        deconv = deconv.to(device)
        empty_batch = torch.randn(0, 4, 4, 4, device=device)
        _test_module_empty_input(self, deconv, empty_batch, check_size=False)

        if self.device_type != "cuda" or not self.has_cudnn():
            return
        with torch.backends.cudnn.flags(enabled=False):
            _test_module_empty_input(self, deconv, empty_batch, check_size=False)

    @onlyCUDA
    @largeTensorTest("12GB")
    @serialTest()
    def test_conv_large_nosplit(self, device):
        # Smoke test only: these large convolutions are expected to be routed
        # to the fallback implementation, and the test passes as long as they
        # do not crash. Numerical correctness of the fallback is left to
        # other tests.
        on_cuda = self.device_type == "cuda"
        dtype = torch.half if on_cuda else torch.float

        # Case 1: 8x8 kernel with stride 8 over a very wide input.
        strided_conv = nn.Conv2d(2, 2, kernel_size=8, stride=8).to(device).to(dtype)
        input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device)
        strided_conv(input_large)

        # Case 2: 1x1 kernel expanding one channel to 1024. The input name is
        # rebound on purpose so the first input is not kept alive.
        pointwise_conv = (
            torch.nn.Conv2d(1, 1024, kernel_size=1, stride=1).to(device).to(dtype)
        )
        input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
        pointwise_conv(input_large)

    def test_conv_noncontig_weights(self, device):
        for dim in (1, 2, 3):
            for grouped in (False, True):
                nc = 3
                groups = 3 if grouped else 1
                w = torch.randn([3] * dim, device=device)
                w = w.expand([nc, int(nc / groups)] + list(w.shape))
                w = w.detach().requires_grad_()
                x = torch.randn(
                    [1, nc] + ([5] * dim), device=device, requires_grad=True
                )
                y = getattr(F, f"conv{dim}d")(x, w, groups=groups)
                y.sum().backward()
                y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups)
                y.sum().backward()

    def test_conv_noncontig_weights_and_bias(self, device):
        # Conv2d must give the same output for strided (sliced) input, weight
        # and bias as for their contiguous copies.
        # float32 is needed to exercise https://github.com/pytorch/pytorch/issues/16018

        def sliced_randn(*shape):
            # Random float tensor with an extra trailing dim of size 2, of
            # which only index 1 is kept.
            return torch.randn((*shape, 2), device=device, dtype=torch.float)[..., 1]

        for with_bias in (True, False):
            conv = nn.Conv2d(
                3, 64, kernel_size=7, stride=2, padding=3, bias=with_bias
            ).to(device, torch.float)

            input_nc = sliced_randn(1, 3, 224, 224)
            input_c = input_nc.contiguous()

            conv.weight = nn.Parameter(sliced_randn(64, 3, 7, 7))
            weight_c = conv.weight.contiguous()

            if with_bias:
                conv.bias = nn.Parameter(sliced_randn(64))
                bias_c = conv.bias.contiguous()

            # First run: everything sliced.
            out_nc = conv(input_nc)

            # Second run: swap in the contiguous copies.
            conv.weight = nn.Parameter(weight_c)
            if with_bias:
                conv.bias = nn.Parameter(bias_c)
            out_c = conv(input_c)

            self.assertEqual(out_nc, out_c)

    @onlyCUDA
    @largeTensorTest("12GB")
    @serialTest()
    def test_conv_transposed_large(self, device):
        # Runs a 1x1 ConvTranspose2d on a 4096-sample batch and checks that
        # each quarter of the result matches the convolution of the
        # corresponding quarter of the input.
        on_cuda = self.device_type == "cuda"
        dtype = torch.half if on_cuda else torch.float
        conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
        input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device)

        # forward
        ret = conv(input_large)
        chunk = 1024
        # Largest absolute difference per chunk. Kept as one expression so no
        # intermediate tensor outlives its iteration.
        max_diffs = [
            (ret.narrow(0, start, chunk) - conv(input_large.narrow(0, start, chunk)))
            .abs_()
            .max()
            .item()
            for start in range(0, 4096, chunk)
        ]

        for max_diff in max_diffs:
            if on_cuda:
                # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0
                self.assertEqual(max_diff, 0, atol=2e-3, rtol=1e-5)
            else:
                self.assertEqual(max_diff, 0)

    @onlyCUDA
    @largeTensorTest("12GB")
    @serialTest()
    def test_conv_large(self, device):
        # Compares a strided Conv2d applied to a 4097-sample batch against the
        # same convolution applied to three sub-batches, for both the forward
        # output and the accumulated weight gradient.
        dtype = torch.half if self.device_type == "cuda" else torch.float
        conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
        batch = 4097
        input_large = torch.randn(batch, 2, 512, 512, dtype=dtype, device=device)
        # [start, stop) batch ranges that together cover the whole input.
        chunks = ((0, 2048), (2048, 4096), (4096, batch))

        # forward
        ret = conv(input_large)
        for lo, hi in chunks:
            self.assertEqual(ret[lo:hi], conv(input_large[lo:hi]))

        # backward
        conv.zero_grad()
        # `max(dim=1)` is used to make the gradient sparse. Without that
        # sparsity the rounding error would be too large (up to 1e-5) to meet
        # the criterion (1e-6) of `assertEqual`.
        ret.view(batch, -1).max(dim=1).values.sum().backward()
        del ret
        grad_full = conv.weight.grad.detach().clone()

        conv.zero_grad()
        # Gradients from the three sub-batches accumulate into weight.grad.
        for lo, hi in chunks:
            conv(input_large[lo:hi]).view(hi - lo, -1).max(dim=1).values.sum().backward()
        grad_chunked = conv.weight.grad.detach().clone()

        # Gradients are on the order of hundreds; rescale both to order one
        # before comparing.
        scale = 1 / grad_chunked.abs().mean()
        self.assertEqual(
            grad_full * scale, grad_chunked * scale, atol=5e-2, rtol=5e-3
        )

    @onlyCUDA
    @largeTensorTest("20GB", "cpu")
    @largeTensorTest("60GB", "cuda")
    @serialTest()
    def test_conv_large_batch_1(self, device):
        # Runs a half-precision 3x3 Conv2d on a single-sample input of
        # 514 * 2048 * 2048 elements (just over 2**31) on `device` and compares
        # the result against a float32 CPU run of the same module.
        in_channels = 514
        dim = 2048
        out_channels = 1
        kernel_size = 3
        stride = 1
        padding = 1

        # Use the device handed in by the test framework rather than the
        # process-wide current CUDA device.
        input_tensor = torch.ones(1, in_channels, dim, dim).to(device).half()
        model = (
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            .to(device)
            .half()
        )
        output = model(input_tensor)

        # Module.cpu()/.float() convert the module in place and return it, so
        # `model_cpu` is the same object as `model`. The device output above
        # must therefore be computed before this line.
        model_cpu = model.cpu().float()
        output_cpu = model_cpu(input_tensor.float().cpu())
        self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    def test_contig_wrong_stride_cudnn(self, device):
        # Feeds conv / transposed conv a tensor whose batch stride is unusual
        # but which still reports itself as contiguous.
        # x has to have batch_size 1 to test contiguous checks
        x = torch.randn(1, 16, 5, 5, device=device)
        # Replace only the stride of dimension 0; the tensor is still
        # contiguous because size[0] is 1.
        altered_stride = [20, *x.stride()[1:]]
        x.set_(x.storage(), 0, x.size(), altered_stride)
        self.assertTrue(x.is_contiguous())
        transposed_weight = torch.randn(16, 1, 1, 1, device=device)
        F.conv_transpose2d(x, transposed_weight)
        conv_weight = torch.randn(1, 16, 1, 1, device=device)
        F.conv2d(x, conv_weight)

    @skipIfRocmArch(MI300_ARCH)
    @onlyCUDA
    @tf32_on_and_off(0.005)
    def test_Conv2d_size_1_kernel(self, device):
        # A 1x1 Conv2d run on `device` with cuDNN disabled must agree with the
        # CPU module (same parameters) in output and in bias/weight gradients.
        cpu_input = torch.randn(2, 3, 5, 5)
        cpu_module = torch.nn.Conv2d(3, 3, kernel_size=1)
        cpu_output = cpu_module(cpu_input)
        upstream = torch.rand_like(cpu_output)
        cpu_output.backward(upstream)

        with cudnn.flags(enabled=False):
            dev_module = torch.nn.Conv2d(3, 3, kernel_size=1).to(device)
            # Share the CPU parameters so both modules compute the same function.
            dev_module.bias.data.copy_(cpu_module.bias.data)
            dev_module.weight.data.copy_(cpu_module.weight.data)
            dev_output = dev_module(cpu_input.to(device))
            dev_output.backward(upstream.to(device))

        tolerances = dict(atol=1e-5, rtol=0, exact_device=False)
        self.assertEqual(cpu_output, dev_output, **tolerances)
        for param_name in ("bias", "weight"):
            self.assertEqual(
                getattr(cpu_module, param_name).grad.data,
                getattr(dev_module, param_name).grad.data,
                **tolerances,
            )

    @skipIfRocmArch(MI300_ARCH)
    @onlyCUDA
    @tf32_on_and_off(0.005)
    def test_ConvTranspose2d_size_1_kernel(self, device):
        # A 1x1 ConvTranspose2d run on `device` with cuDNN disabled must agree
        # with the CPU module (same parameters) in output and gradients.
        def make_layer():
            return torch.nn.ConvTranspose2d(3, 3, kernel_size=1)

        host_x = torch.randn(2, 3, 5, 5)
        host_layer = make_layer()
        host_y = host_layer(host_x)
        grad_y = torch.rand_like(host_y)
        host_y.backward(grad_y)

        with cudnn.flags(enabled=False):
            dev_layer = make_layer().to(device)
            # Copy bias first, then weight, from the CPU layer.
            dev_layer.bias.data.copy_(host_layer.bias.data)
            dev_layer.weight.data.copy_(host_layer.weight.data)
            dev_y = dev_layer(host_x.to(device))
            dev_y.backward(grad_y.to(device))

        comparisons = [
            (host_y, dev_y),
            (host_layer.bias.grad.data, dev_layer.bias.grad.data),
            (host_layer.weight.grad.data, dev_layer.weight.grad.data),
        ]
        for expected, actual in comparisons:
            self.assertEqual(expected, actual, atol=1e-5, rtol=0, exact_device=False)

    @onlyCUDA
    def test_ConvTranspose3d_size_1_kernel(self, device):
        # Same CPU-vs-device comparison as the 2d size-1-kernel tests, but for
        # ConvTranspose3d and with double as the default dtype.
        with set_default_dtype(torch.double):
            reference_in = torch.randn(2, 3, 3, 5, 5)
            reference_net = torch.nn.ConvTranspose3d(3, 3, kernel_size=1)
            reference_out = reference_net(reference_in)
            out_grad = torch.rand_like(reference_out)
            reference_out.backward(out_grad)

            with cudnn.flags(enabled=False):
                device_net = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device)
                device_net.bias.data.copy_(reference_net.bias.data)
                device_net.weight.data.copy_(reference_net.weight.data)
                device_out = device_net(reference_in.to(device))
                device_out.backward(out_grad.to(device))

            check = dict(atol=1e-5, rtol=0, exact_device=False)
            self.assertEqual(reference_out, device_out, **check)
            self.assertEqual(
                reference_net.bias.grad.data, device_net.bias.grad.data, **check
            )
            self.assertEqual(
                reference_net.weight.grad.data, device_net.weight.grad.data, **check
            )

    @dtypesIfCUDA(
        *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])
    )
    @dtypes(torch.float)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @torch.backends.miopen.flags(immediate=True)
    @tf32_on_and_off(0.001)
    def test_Conv2d_naive_groups(self, device, dtype):
        # Check that a grouped convolution (groups=2) matches two independent
        # convolutions, each applied to one half of the channels.
        m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
        i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
        output = m(i)
        grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype)
        output.backward(grad_output)

        # Build one ungrouped conv per channel half, sharing the matching
        # slice of the grouped module's parameters.
        halves = []
        for chans in (slice(0, 2), slice(2, 4)):
            part = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype)
            part.weight.data.copy_(m.weight.data[chans])
            part.bias.data.copy_(m.bias.data[chans])
            part_in = i.data[:, chans].contiguous().requires_grad_(True)
            part_out = part(part_in)
            part_out.backward(grad_output[:, chans].contiguous())
            halves.append((part, part_in, part_out))
        (m1, i1, output1), (m2, i2, output2) = halves

        tol = dict(atol=dtype2prec_DONTUSE[dtype], rtol=0)
        self.assertEqual(output, torch.cat([output1, output2], 1))
        self.assertEqual(
            i.grad.data, torch.cat([i1.grad.data, i2.grad.data], 1), **tol
        )
        self.assertEqual(
            m.bias.grad.data,
            torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0),
            **tol,
        )
        self.assertEqual(
            m.weight.grad.data,
            torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0),
            **tol,
        )

    @dtypes(torch.double, torch.cdouble)
    @dtypesIfMPS(torch.float, torch.cfloat)
    @expectedFailureMPS  # https://github.com/pytorch/pytorch/issues/107214
    def test_Conv2d_backward_depthwise(self, device, dtype):
        # gradcheck a strided depthwise conv2d (groups == channels == 2), once
        # with cuDNN disabled and once with it enabled.
        x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
        weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)
        depthwise_kwargs = dict(bias=None, stride=(1, 10), groups=2)

        def conv2d_depthwise(inp, kernel):
            return torch.nn.functional.conv2d(inp, kernel, **depthwise_kwargs)

        for use_cudnn in (False, True):
            with torch.backends.cudnn.flags(enabled=use_cudnn):
                torch.autograd.gradcheck(conv2d_depthwise, (x, weight))

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_conv_thnn_nhwc(self, device, dtype):
        # With mkldnn disabled, compares convolutions whose input and/or weight
        # are channels_last against a fully contiguous reference: forward
        # output plus weight, bias and input gradients.
        def helper(
            mod,
            n,
            c,
            h,
            w,
            out_channels,
            kernel_size,
            dilation,
            groups,
            input_format,
            weight_format,
        ):
            conv_kwargs = dict(dilation=dilation, groups=groups)

            # Integer-valued data in [-3, 3) for input, parameters and grad.
            x = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)
            x = x.to(memory_format=input_format)
            x.requires_grad_()
            layer = mod(c, out_channels, kernel_size, **conv_kwargs)
            layer = layer.to(device="cpu", dtype=dtype, memory_format=weight_format)
            for param in layer.parameters():
                param.data = torch.randint_like(param, -3, 3)

            x_ref = x.detach().clone().contiguous().requires_grad_()
            layer_ref = mod(c, out_channels, kernel_size, **conv_kwargs)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            layer_ref.load_state_dict(layer.state_dict())
            layer_ref = layer_ref.to(
                device="cpu", dtype=dtype, memory_format=torch.contiguous_format
            )

            y = layer(x)
            y_ref = layer_ref(x_ref)

            dy = torch.randint_like(y, -3, 3)
            dy_ref = dy.detach().clone().contiguous()

            y.backward(dy)
            y_ref.backward(dy_ref)

            self.assertTrue(y.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(y_ref.is_contiguous())
            for actual, expected in (
                (y, y_ref),
                (layer.weight.grad, layer_ref.weight.grad),
                (layer.bias.grad, layer_ref.bias.grad),
                (x.grad, x_ref.grad),
            ):
                self.assertEqual(actual, expected, exact_dtype=False)

        cl, cf = torch.channels_last, torch.contiguous_format
        formats = [
            [cl, cl],
            [cl, cf],
            [cf, cl],
        ]
        # Each case: (module, (n, c, h, w), out_channels, kernel_size, dilation,
        # groups, fixed input format, fixed weight format). A fixed format of
        # None means "use the format of the current loop iteration".
        cases = [
            # non-dilated conv: thnn_conv2d normal path (with im2col)
            (nn.Conv2d, (2, 8, 4, 4), 4, 3, 1, 1, None, None),
            (nn.Conv2d, (2, 8, 4, 4), 8, 3, 1, 8, None, None),
            # test when input channels is 1 and not converted to channels last
            (nn.Conv2d, (2, 1, 10, 10), 8, 3, 1, 1, cf, cl),
            # non-dilated conv: thnn_conv2d fast path (skip im2col)
            (nn.Conv2d, (1, 16, 56, 56), 16, 1, 1, 1, None, None),
            # ic == oc == 1 here, so need to stick input to CL to activate channels last
            (nn.Conv2d, (1, 16, 56, 56), 16, 1, 1, 16, cl, None),
            # dilated conv: slow_conv_dilated2d
            (nn.Conv2d, (2, 8, 11, 13), 16, 3, 2, 1, None, None),
            (nn.Conv2d, (2, 16, 11, 13), 32, 3, 2, 16, None, None),
            # transposed-conv: slow_conv_transpose2d
            (nn.ConvTranspose2d, (2, 8, 4, 4), 4, 3, 1, 1, None, None),
            (nn.ConvTranspose2d, (2, 8, 4, 4), 8, 3, 1, 8, None, None),
            (nn.ConvTranspose2d, (1, 16, 56, 56), 16, 1, 1, 1, None, None),
            (nn.ConvTranspose2d, (1, 16, 56, 56), 32, 1, 1, 16, None, None),
        ]

        with torch.backends.mkldnn.flags(enabled=False):
            for input_format, weight_format in formats:
                for mod, shape, oc, ksize, dil, grp, fixed_in, fixed_w in cases:
                    helper(
                        mod,
                        *shape,
                        out_channels=oc,
                        kernel_size=ksize,
                        dilation=dil,
                        groups=grp,
                        input_format=input_format if fixed_in is None else fixed_in,
                        weight_format=weight_format if fixed_w is None else fixed_w,
                    )

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.cfloat)
    def test_conv_cudnn_nhwc(self, device, dtype):
        # Runs a channels_last Conv2d forward/backward on `device` and compares
        # output and gradients against a float64 contiguous reference conv that
        # loads the same parameters.
        def helper(n, c, h, w, out_channels, kernel_size, groups):
            # randint with dtype=torch.cfloat fails with
            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
            # must create randint and randint_like using default int64, then cast to desired
            input = torch.randint(
                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
            ).to(dtype, memory_format=torch.channels_last)
            input.requires_grad_()
            # Use the `device` supplied by the test framework (not a literal
            # "cuda") so the module always lives on the same device as `input`.
            conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # The channels_last run must keep channels_last layout end to end.
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last)
            )

            # The reference run must stay in the default contiguous layout.
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

    @onlyCUDA
    @dtypes(torch.half, torch.float)
    def test_conv_cudnn_ndhwc(self, device, dtype):
        # Runs a channels_last_3d Conv3d forward/backward on `device` and
        # compares output and gradients against a float64 contiguous reference
        # conv that loads the same parameters.
        def helper(n, c, d, h, w, out_channels, kernel_size, groups):
            input = torch.randint(
                -2, 2, (n, c, d, h, w), dtype=dtype, device=device
            ).to(memory_format=torch.channels_last_3d)
            input.requires_grad_()
            # Use the `device` supplied by the test framework (not a literal
            # "cuda") so the module always lives on the same device as `input`.
            conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to(
                device=device, dtype=dtype, memory_format=torch.channels_last_3d
            )
            for p in conv.parameters():
                p.data = torch.randint_like(p, -2, 2)

            # use FP64 channels-first conv as reference
            ref_input = input.detach().clone().contiguous().double().requires_grad_()
            ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)
            # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
            ref_conv.load_state_dict(conv.state_dict())
            ref_conv = ref_conv.to(
                device=device, dtype=torch.double, memory_format=torch.contiguous_format
            )

            out = conv(input)
            ref_out = ref_conv(ref_input)

            grad = torch.randint_like(out, -2, 2)
            ref_grad = grad.detach().clone().double().contiguous()

            out.backward(grad)
            ref_out.backward(ref_grad)

            # The channels_last_3d run must keep that layout end to end.
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(
                input.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )
            self.assertTrue(
                conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)
            )

            # The reference run must stay in the default contiguous layout.
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_input.grad.is_contiguous())
            self.assertTrue(ref_conv.weight.grad.is_contiguous())

            self.assertEqual(out, ref_out, exact_dtype=False)
            self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False)
            self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False)
            self.assertEqual(input.grad, ref_input.grad, exact_dtype=False)

        helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1)
        helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1)
        helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16)

    def _run_conv(
        self,
        layer,
        device,
        inp,
        grad,
        ref_conv,
        ref_input,
        ref_out,
        input_format,
        weight_format,
        grad_format,
        output_format,
    ):
        # Re-runs the reference convolution with weight, input and upstream
        # grad laid out in the requested memory formats, then checks the output
        # layout and compares output and gradients with the reference run.
        in_channels = inp.size(1)
        out_channels = grad.size(1)
        kernel = ref_conv.weight.size(2)
        conv = layer(in_channels, out_channels, kernel).float().to(device)
        # load_state_dict will restore the stride & memory_layout on ref_conv.weight.
        conv.load_state_dict(ref_conv.state_dict())

        # resize_ to the same size with an explicit memory_format follows each
        # contiguous() call -- presumably to force the requested strides even
        # when the shape makes the format ambiguous; confirm.
        weight_copy = conv.weight.detach().clone()
        weight_copy = weight_copy.contiguous(memory_format=weight_format)
        conv.weight.data = weight_copy.resize_(
            weight_copy.size(), memory_format=weight_format
        )

        x = inp.clone().contiguous(memory_format=input_format)
        x.resize_(x.size(), memory_format=input_format)
        x = x.requires_grad_()

        grad = grad.contiguous(memory_format=grad_format)
        grad.resize_(grad.size(), memory_format=grad_format)

        out = conv(x)
        out.backward(grad)

        self.assertTrue(out.is_contiguous(memory_format=output_format))
        for actual, expected in (
            (out, ref_out),
            (conv.weight.grad, ref_conv.weight.grad),
            (conv.bias.grad, ref_conv.bias.grad),
            (x.grad, ref_input.grad),
        ):
            self.assertEqual(actual, expected)

    def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
        # Builds a contiguous float32 reference run of `layer(c, k, filter_size)`
        # on an (n, c, h, w) input, then replays it through `_run_conv` for
        # every combination of weight / grad / input memory format.
        data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device)
        ref_input = data.clone().contiguous().requires_grad_(True)
        ref_conv = layer(c, k, filter_size).float().to(device)
        ref_out = ref_conv(ref_input)
        # Create the upstream gradient on `device` (not a literal "cuda") so it
        # always lives on the same device as `ref_out`.
        grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device)
        ref_out.backward(grad)

        formats = (torch.contiguous_format, torch.channels_last)
        # product() varies the last element fastest: weight format is the
        # outermost loop, input format the innermost.
        for w_f, g_f, input_format in product(formats, formats, formats):
            output_format = torch.contiguous_format
            # Older versions of CudNN have Channels Last support disabled
            if torch.backends.cudnn.version() >= 7603:
                if input_format == torch.channels_last:
                    output_format = torch.channels_last
                # This is because we have N111 weight that cannot handle
                # the ambiguous memory_format
                if w_f == torch.channels_last:
                    if layer is nn.Conv2d and filter_size * c != 1:
                        output_format = torch.channels_last
                    if layer is nn.ConvTranspose2d and filter_size * k != 1:
                        output_format = torch.channels_last
            self._run_conv(
                layer,
                device,
                data,
                grad,
                ref_conv,
                ref_input,
                ref_out,
                input_format,
                w_f,
                g_f,
                output_format,
            )

    @onlyCUDA
    @tf32_on_and_off(0.05)
    def test_conv_cudnn_mismatch_memory_format(self, device):
        # Runs the mixed-memory-format check for Conv2d and then
        # ConvTranspose2d over a set of shapes.
        # Each entry is (n, c, h, w, k, filter_size).
        shape_configs = (
            (4, 2, 8, 8, 4, 2),
            (4, 1, 8, 8, 4, 2),
            (1, 1, 8, 8, 4, 2),
            (4, 2, 2, 8, 4, 1),
            (4, 2, 1, 8, 4, 1),
            (4, 2, 8, 8, 4, 1),
            (4, 1, 8, 8, 4, 1),
        )
        layer_types = (nn.Conv2d, nn.ConvTranspose2d)
        for config in shape_configs:
            for layer in layer_types:
                self._test_conv_cudnn_nhwc_nchw(layer, *config, device)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    def test_conv_cudnn_nhwc_support(self, device, dtype):
        # conv2d with a channels_last weight and a 1x1-spatial input must
        # return a channels_last output, and backward must run.
        # Tensors are created on the `device` supplied by the test framework
        # rather than on a hard-coded "cuda".
        input = torch.randn(
            (1, 16, 1, 1), dtype=dtype, device=device, requires_grad=True
        )
        weight = torch.randn(
            (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True
        )
        weight = weight.to(memory_format=torch.channels_last)
        # Positional args: bias=None, stride=(2, 1), padding=(1, 1),
        # dilation=(1, 1), groups=1.
        o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
        self.assertTrue(o.is_contiguous(memory_format=torch.channels_last))
        o.sum().backward()

    # Test that faster algorithms used for inference produce the same results
    # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176
    @onlyCPU
    @dtypes(torch.float)
    def test_conv2d_no_grad(self, device, dtype):
        # The output computed under torch.no_grad() must match the output
        # computed with autograd enabled, for several batch sizes and groups.
        for batch, groups in product((1, 2, 3), (1, 2, 4)):
            sample = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device)
            conv = nn.Conv2d(
                groups,
                8,
                kernel_size=(3, 3),
                groups=groups,
                dtype=dtype,
                device=device,
            )
            with torch.no_grad():
                inference_out = conv(sample)
            training_out = conv(sample)
            self.assertEqual(training_out, inference_out, rtol=1e-2, atol=1e-5)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_relu(self, device, dtype):
        # The fused convolution+relu op (MIOpen variant on HIP builds, cuDNN
        # otherwise) must match relu(conv2d(...)) and keep the memory format.
        # Shared trailing args: bias, stride, padding, dilation, groups.
        conv_args = (None, (1, 1), (0, 0), (1, 1), 1)
        configs = product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        )
        for batch, groups, image_size, kernel_size, memory_format in configs:
            # Skip images smaller than the kernel.
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            conv2d_out = torch.conv2d(inp, w, *conv_args)
            fused_op = (
                torch.miopen_convolution_relu
                if torch.version.hip
                else torch.cudnn_convolution_relu
            )
            cudnn_out = fused_op(inp, w, *conv_args)
            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            # Explicit, looser tolerances only for float when TF32 is supported.
            loose = torch.cuda.is_tf32_supported() and dtype == torch.float
            tolerances = dict(atol=4e-3, rtol=0.006) if loose else {}
            self.assertEqual(conv2d_out.relu(), cudnn_out, **tolerances)

    @onlyCUDA
    @skipCUDAIfNoCudnn
    @dtypes(torch.float, torch.float16)
    @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
    @precisionOverride({torch.half: 0.002, torch.float: 1e-4})
    def test_cudnn_convolution_add_relu(self, device, dtype):
        # The fused convolution+add+relu op (MIOpen variant on HIP builds,
        # cuDNN otherwise) must match relu(conv2d(...) + alpha * z) and keep
        # the memory format.
        # Shared trailing args: bias, stride, padding, dilation, groups.
        conv_args = (None, (1, 1), (0, 0), (1, 1), 1)
        alpha = 2.0
        configs = product(
            (1, 2, 3),
            (1, 2, 4),
            ((1, 1), (8, 8)),
            ((1, 1), (3, 3)),
            (torch.channels_last, torch.contiguous_format),
        )
        for batch, groups, image_size, kernel_size, memory_format in configs:
            # Skip images smaller than the kernel.
            if image_size[0] < kernel_size[0]:
                continue
            inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device)
            w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device)
            inp = inp.to(memory_format=memory_format)
            w = w.to(memory_format=memory_format)
            conv2d_out = torch.conv2d(inp, w, *conv_args)
            z = torch.randn_like(conv2d_out).to(memory_format=memory_format)
            fused_op = (
                torch.miopen_convolution_add_relu
                if torch.version.hip
                else torch.cudnn_convolution_add_relu
            )
            cudnn_out = fused_op(inp, w, z, alpha, *conv_args)

            self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format))
            # Explicit, looser tolerances only for float when TF32 is supported.
            loose = torch.cuda.is_tf32_supported() and dtype == torch.float
            tolerances = dict(atol=2e-3, rtol=0.006) if loose else {}
            self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out, **tolerances)

    @onlyCUDA
    def test_convert_conv2d_weight_memory_format(self, device):
        # After convert_conv2d_weight_memory_format, a conv + batchnorm model
        # must produce output in the requested memory format; checked for
        # Conv2d and then ConvTranspose2d.
        input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
        target_formats = [torch.channels_last, torch.contiguous_format]
        for conv_cls in (nn.Conv2d, nn.ConvTranspose2d):
            stack = nn.Sequential(conv_cls(8, 4, 3), nn.BatchNorm2d(4))
            model = stack.to(device).float()
            for memory_format in target_formats:
                model = nn.utils.convert_conv2d_weight_memory_format(
                    model, memory_format
                )
                result = model(input)
                self.assertTrue(result.is_contiguous(memory_format=memory_format))

    @onlyCUDA
    def test_convert_conv3d_weight_memory_format(self, device):
        # After convert_conv3d_weight_memory_format, a ConvTranspose3d +
        # BatchNorm3d model must produce output in the requested memory format.
        volume_shape = (2, 8, 4, 4, 4)
        input = torch.randint(1, 10, volume_shape, dtype=torch.float32, device=device)
        stack = nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4))
        model = stack.to(device).float()
        for memory_format in (torch.channels_last_3d, torch.contiguous_format):
            model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format)
            result = model(input)
            self.assertTrue(result.is_contiguous(memory_format=memory_format))

    def test_conv_double_backward_strided_with_3D_input_and_weight(self, device):
        # Test that _convolution_double_backward() outputs the correct grad shapes
        # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a
        # specific case that was uncovered during the convolution consolidation effort.
        # The test can be safely deleted if _convolution_double_backward() is removed.

        input = torch.randn(2, 3, 6, device=device)
        weight = torch.randn(3, 3, 3, device=device)
        bias = torch.randn(3, device=device)
        # Shared convolution settings, in the positional order both aten ops
        # take them: stride, padding, dilation, transposed, output_padding, groups.
        conv_settings = ((2,), (1,), (1,), False, (0,), 1)
        output = torch.ops.aten.convolution(input, weight, bias, *conv_settings)

        ggI = torch.randn(input.shape, device=device)
        ggW = torch.randn(weight.shape, device=device)
        ggB = torch.randn(bias.shape, device=device)
        gO = torch.randn(output.shape, device=device)
        output_mask = [True, True, True]
        results = torch.ops.aten._convolution_double_backward(
            ggI, ggW, ggB, gO, weight, input, *conv_settings, output_mask
        )
        grad_grad_output, grad_input, grad_weight = results

        # Make sure the correct shapes are computed.
        expected_shapes = (
            (grad_grad_output, gO),
            (grad_input, input),
            (grad_weight, weight),
        )
        for computed, reference in expected_shapes:
            self.assertEqual(computed.shape, reference.shape)

    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("40GB")
    @largeTensorTest("24GB", "cpu")
    @serialTest()
    @tf32_on_and_off(0.005)
    def test_conv3d_64bit_indexing(self, device):
        # A 1x1x1 Conv3d on a 1 x 32 x 512 x 512 x 256 input: the CPU result
        # is the reference for the same module and input moved to `device`.
        cpu_input = torch.rand(1, 32, 512, 512, 256)
        conv = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False)
        expected = conv(cpu_input)
        conv_on_device = conv.to(device=device)
        actual = conv_on_device(cpu_input.to(device=device))
        self.assertEqual(expected, actual)

    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("48GB", "cuda")
    @serialTest()
    def test_conv3d_cudnn_broken(self, device):
        # For half and bfloat16, the default Conv3d result on a large input
        # must equal the result computed with cuDNN disabled.
        conv_kwargs = dict(kernel_size=(1, 3, 3), padding=0, stride=1, bias=False)
        for dtype in (torch.half, torch.bfloat16):
            x = torch.rand(1, 16, 124, 1282, 722, dtype=dtype, device=device)
            conv = torch.nn.Conv3d(16, 16, dtype=dtype, device=device, **conv_kwargs)
            with torch.backends.cudnn.flags(enabled=False):
                expected = conv(x)
            actual = conv(x)
            self.assertEqual(expected, actual)

    @skipCUDAIfRocm
    @onlyCUDA
    @largeTensorTest("20GB")
    @largeTensorTest("64GB", "cpu")
    @serialTest()
    @xfailIf(
        _get_cudnn_version() is not None and (91000 < _get_cudnn_version() < 91500)
    )
    def test_depthwise_conv_64bit_indexing(self, device):
        # Depthwise (groups == channels) half-precision conv on a very large
        # channels_last input: the CPU result is the reference for `device`.
        channels_last = torch.channels_last
        x = torch.randn(1, 2, 32800, 32800, dtype=torch.half)
        x = x.to(memory_format=channels_last)
        conv = nn.Conv2d(
            2, 2, kernel_size=3, stride=1, padding=1, groups=2, dtype=torch.half
        )
        conv = conv.to(memory_format=channels_last)
        tolerances = dict(atol=5e-3, rtol=1e-4)

        expected = conv(x)
        actual = conv.to(device=device)(x.to(device=device))
        self.assertEqual(expected, actual, **tolerances)
        # Drop the large results before allocating the next ones.
        del actual, expected

        # try a batch-splittable case
        x = x.reshape(100, 2, 3280, 3280)
        x = x.contiguous(memory_format=channels_last)
        # The module was moved to `device` above; bring it back for the CPU pass.
        expected = conv.cpu()(x)
        actual = conv.to(device=device)(x.to(device=device))
        self.assertEqual(expected, actual, **tolerances)


# Instantiate the device-type tests of TestConvolutionNNDeviceType (with
# allow_mps=True) and the parametrized tests of TestConvolutionNN.
# NOTE(review): presumably the generated per-device classes are registered in
# the globals() dict passed here -- confirm against the helper's documentation.
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestConvolutionNN)

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
